diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 4d5b9ffefcd..cca18b91675 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -15,10 +15,14 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.mapper.FieldAliasMapper; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import java.io.IOException; import java.util.Arrays; @@ -46,6 +50,7 @@ public class Classification implements DataFrameAnalysis { public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); public static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); + public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors"); private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1"; @@ -59,6 +64,7 @@ public class Classification implements 
DataFrameAnalysis { */ public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30; + @SuppressWarnings("unchecked") private static ConstructingObjectParser createParser(boolean lenient) { ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), @@ -70,7 +76,8 @@ public class Classification implements DataFrameAnalysis { (ClassAssignmentObjective) a[8], (Integer) a[9], (Double) a[10], - (Long) a[11])); + (Long) a[11], + (List) a[12])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -78,6 +85,12 @@ public class Classification implements DataFrameAnalysis { parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT); parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED); + parser.declareNamedObjects(optionalConstructorArg(), + (p, c, n) -> lenient ? + p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) : + p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)), + (classification) -> {/*TODO should we throw if this is not set?*/}, + FEATURE_PROCESSORS); return parser; } @@ -119,6 +132,7 @@ public class Classification implements DataFrameAnalysis { private final int numTopClasses; private final double trainingPercent; private final long randomizeSeed; + private final List featureProcessors; public Classification(String dependentVariable, BoostedTreeParams boostedTreeParams, @@ -126,7 +140,8 @@ public class Classification implements DataFrameAnalysis { @Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable Integer numTopClasses, @Nullable Double trainingPercent, - @Nullable Long randomizeSeed) { + @Nullable Long randomizeSeed, + @Nullable List featureProcessors) { if (numTopClasses != null && (numTopClasses < 0 || 
numTopClasses > 1000)) { throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName()); } @@ -141,10 +156,11 @@ public class Classification implements DataFrameAnalysis { this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed; + this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors); } public Classification(String dependentVariable) { - this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null); } public Classification(StreamInput in) throws IOException { @@ -163,6 +179,11 @@ public class Classification implements DataFrameAnalysis { } else { randomizeSeed = Randomness.get().nextLong(); } + if (in.getVersion().onOrAfter(Version.V_7_10_0)) { + featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class)); + } else { + featureProcessors = Collections.emptyList(); + } } public String getDependentVariable() { @@ -193,6 +214,10 @@ public class Classification implements DataFrameAnalysis { return randomizeSeed; } + public List getFeatureProcessors() { + return featureProcessors; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -211,6 +236,9 @@ public class Classification implements DataFrameAnalysis { if (out.getVersion().onOrAfter(Version.V_7_6_0)) { out.writeOptionalLong(randomizeSeed); } + if (out.getVersion().onOrAfter(Version.V_7_10_0)) { + out.writeNamedWriteableList(featureProcessors); + } } @Override @@ -229,6 +257,9 @@ public class Classification implements DataFrameAnalysis { if (version.onOrAfter(Version.V_7_6_0)) { 
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); } + if (featureProcessors.isEmpty() == false) { + NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors); + } builder.endObject(); return builder; } @@ -249,6 +280,10 @@ public class Classification implements DataFrameAnalysis { } params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable)); params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent); + if (featureProcessors.isEmpty() == false) { + params.put(FEATURE_PROCESSORS.getPreferredName(), + featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList())); + } return params; } @@ -390,6 +425,7 @@ public class Classification implements DataFrameAnalysis { && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(classAssignmentObjective, that.classAssignmentObjective) && Objects.equals(numTopClasses, that.numTopClasses) + && Objects.equals(featureProcessors, that.featureProcessors) && trainingPercent == that.trainingPercent && randomizeSeed == that.randomizeSeed; } @@ -397,7 +433,7 @@ public class Classification implements DataFrameAnalysis { @Override public int hashCode() { return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective, - numTopClasses, trainingPercent, randomizeSeed); + numTopClasses, trainingPercent, randomizeSeed, featureProcessors); } public enum ClassAssignmentObjective { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index b4b06187fbe..dbc70dc219b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ 
-15,9 +15,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import java.io.IOException; import java.util.Arrays; @@ -28,6 +32,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis { public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); public static final ParseField LOSS_FUNCTION = new ParseField("loss_function"); public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter"); + public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors"); private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1"; private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + @SuppressWarnings("unchecked") private static ConstructingObjectParser createParser(boolean lenient) { ConstructingObjectParser 
parser = new ConstructingObjectParser<>( NAME.getPreferredName(), @@ -59,7 +66,8 @@ public class Regression implements DataFrameAnalysis { (Double) a[8], (Long) a[9], (LossFunction) a[10], - (Double) a[11])); + (Double) a[11], + (List) a[12])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -67,6 +75,12 @@ public class Regression implements DataFrameAnalysis { parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED); parser.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION); parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER); + parser.declareNamedObjects(optionalConstructorArg(), + (p, c, n) -> lenient ? + p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) : + p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)), + (regression) -> {/*TODO should we throw if this is not set?*/}, + FEATURE_PROCESSORS); return parser; } @@ -90,6 +104,7 @@ public class Regression implements DataFrameAnalysis { private final long randomizeSeed; private final LossFunction lossFunction; private final Double lossFunctionParameter; + private final List featureProcessors; public Regression(String dependentVariable, BoostedTreeParams boostedTreeParams, @@ -97,7 +112,8 @@ public class Regression implements DataFrameAnalysis { @Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction, - @Nullable Double lossFunctionParameter) { + @Nullable Double lossFunctionParameter, + @Nullable List featureProcessors) { if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName()); } @@ -112,10 +128,11 @@ public class Regression implements 
DataFrameAnalysis { throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName()); } this.lossFunctionParameter = lossFunctionParameter; + this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors); } public Regression(String dependentVariable) { - this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null); } public Regression(StreamInput in) throws IOException { @@ -136,6 +153,11 @@ public class Regression implements DataFrameAnalysis { lossFunction = LossFunction.MSE; lossFunctionParameter = null; } + if (in.getVersion().onOrAfter(Version.V_7_10_0)) { + featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class)); + } else { + featureProcessors = Collections.emptyList(); + } } public String getDependentVariable() { @@ -166,6 +188,10 @@ public class Regression implements DataFrameAnalysis { return lossFunctionParameter; } + public List getFeatureProcessors() { + return featureProcessors; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -184,6 +210,9 @@ public class Regression implements DataFrameAnalysis { out.writeEnum(lossFunction); out.writeOptionalDouble(lossFunctionParameter); } + if (out.getVersion().onOrAfter(Version.V_7_10_0)) { + out.writeNamedWriteableList(featureProcessors); + } } @Override @@ -204,6 +233,9 @@ public class Regression implements DataFrameAnalysis { if (lossFunctionParameter != null) { builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter); } + if (featureProcessors.isEmpty() == false) { + NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors); + } builder.endObject(); return builder; } @@ -221,6 +253,10 @@ public class 
Regression implements DataFrameAnalysis { if (lossFunctionParameter != null) { params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter); } + if (featureProcessors.isEmpty() == false) { + params.put(FEATURE_PROCESSORS.getPreferredName(), + featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList())); + } return params; } @@ -304,13 +340,14 @@ public class Regression implements DataFrameAnalysis { && trainingPercent == that.trainingPercent && randomizeSeed == that.randomizeSeed && lossFunction == that.lossFunction + && Objects.equals(featureProcessors, that.featureProcessors) && Objects.equals(lossFunctionParameter, that.lossFunctionParameter); } @Override public int hashCode() { return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, - lossFunctionParameter); + lossFunctionParameter, featureProcessors); } public enum LossFunction { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 9f225a997db..2ba7f114e8b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -57,23 +57,23 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider { // PreProcessing Lenient namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME, - OneHotEncoding::fromXContentLenient)); + (p, c) -> OneHotEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME, - 
TargetMeanEncoding::fromXContentLenient)); + (p, c) -> TargetMeanEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME, - FrequencyEncoding::fromXContentLenient)); + (p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME, - CustomWordEmbedding::fromXContentLenient)); + (p, c) -> CustomWordEmbedding.fromXContentLenient(p))); // PreProcessing Strict namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME, - OneHotEncoding::fromXContentStrict)); + (p, c) -> OneHotEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME, - TargetMeanEncoding::fromXContentStrict)); + (p, c) -> TargetMeanEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME, - FrequencyEncoding::fromXContentStrict)); + (p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME, - CustomWordEmbedding::fromXContentStrict)); + (p, c) -> CustomWordEmbedding.fromXContentStrict(p))); // Model Lenient namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index 9493572b838..24ec52b3650 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -56,8 +56,8 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco TRAINED_MODEL); parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors, (p, c, n) -> ignoreUnknownFields ? - p.namedObject(LenientlyParsedPreProcessor.class, n, null) : - p.namedObject(StrictlyParsedPreProcessor.class, n, null), + p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) : + p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT), (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true), PREPROCESSORS); return parser; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java index 8518caf66eb..bac61d9b8ef 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java @@ -50,15 +50,15 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl public static final ParseField EMBEDDING_WEIGHTS = new ParseField("embedding_weights"); public static final ParseField EMBEDDING_QUANT_SCALES = new ParseField("embedding_quant_scales"); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser 
LENIENT_PARSER = createParser(true); @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean lenient) { - ConstructingObjectParser parser = new ConstructingObjectParser<>( + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), lenient, - a -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3])); + (a, c) -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3])); parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> { @@ -123,11 +123,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl } public static CustomWordEmbedding fromXContentStrict(XContentParser parser) { - return STRICT_PARSER.apply(parser, null); + return STRICT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT); } public static CustomWordEmbedding fromXContentLenient(XContentParser parser) { - return LENIENT_PARSER.apply(parser, null); + return LENIENT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT); } private static final int CONCAT_LAYER_SIZE = 80; @@ -256,6 +256,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl return false; } + @Override + public String getOutputFieldType(String outputField) { + return "dense_vector"; + } + @Override public long ramBytesUsed() { long size = SHALLOW_SIZE; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java index 34f7e8cae9e..dc9b0d6aeb6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java 
@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -37,15 +38,18 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map"); public static final ParseField CUSTOM = new ParseField("custom"); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean lenient) { - ConstructingObjectParser parser = new ConstructingObjectParser<>( + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), lenient, - a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map)a[2], (Boolean)a[3])); + (a, c) -> new FrequencyEncoding((String)a[0], + (String)a[1], + (Map)a[2], + a[3] == null ? 
c.isCustomByDefault() : (Boolean)a[3])); parser.declareString(ConstructingObjectParser.constructorArg(), FIELD); parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME); parser.declareObject(ConstructingObjectParser.constructorArg(), @@ -55,12 +59,12 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP return parser; } - public static FrequencyEncoding fromXContentStrict(XContentParser parser) { - return STRICT_PARSER.apply(parser, null); + public static FrequencyEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) { + return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context); } - public static FrequencyEncoding fromXContentLenient(XContentParser parser) { - return LENIENT_PARSER.apply(parser, null); + public static FrequencyEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) { + return LENIENT_PARSER.apply(parser, context == null ? 
PreProcessorParseContext.DEFAULT : context); } private final String field; @@ -117,6 +121,11 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP return custom; } + @Override + public String getOutputFieldType(String outputField) { + return NumberFieldMapper.NumberType.DOUBLE.typeName(); + } + @Override public String getName() { return NAME.getPreferredName(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index cf86fe986f4..ade6a659f88 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -36,27 +37,29 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars public static final ParseField HOT_MAP = new ParseField("hot_map"); public static final ParseField CUSTOM = new ParseField("custom"); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean lenient) { - 
ConstructingObjectParser parser = new ConstructingObjectParser<>( + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), lenient, - a -> new OneHotEncoding((String)a[0], (Map)a[1], (Boolean)a[2])); + (a, c) -> new OneHotEncoding((String)a[0], + (Map)a[1], + a[2] == null ? c.isCustomByDefault() : (Boolean)a[2])); parser.declareString(ConstructingObjectParser.constructorArg(), FIELD); parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP); parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM); return parser; } - public static OneHotEncoding fromXContentStrict(XContentParser parser) { - return STRICT_PARSER.apply(parser, null); + public static OneHotEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) { + return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context); } - public static OneHotEncoding fromXContentLenient(XContentParser parser) { - return LENIENT_PARSER.apply(parser, null); + public static OneHotEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) { + return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context); } private final String field; @@ -103,6 +106,11 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars return custom; } + @Override + public String getOutputFieldType(String outputField) { + return NumberFieldMapper.NumberType.INTEGER.typeName(); + } + @Override public String getName() { return NAME.getPreferredName(); @@ -124,8 +132,9 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars if (value == null) { return; } + final String stringValue = value.toString(); hotMap.forEach((val, col) -> { - int encoding = value.toString().equals(val) ? 
1 : 0; + int encoding = stringValue.equals(val) ? 1 : 0; fields.put(col, encoding); }); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java index c5605af6295..59666477370 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java @@ -18,6 +18,18 @@ import java.util.Map; */ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable { + class PreProcessorParseContext { + public static final PreProcessorParseContext DEFAULT = new PreProcessorParseContext(false); + final boolean defaultIsCustomValue; + public PreProcessorParseContext(boolean defaultIsCustomValue) { + this.defaultIsCustomValue = defaultIsCustomValue; + } + + public boolean isCustomByDefault() { + return defaultIsCustomValue; + } + } + /** * The expected input fields */ @@ -48,4 +60,6 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou */ boolean isCustom(); + String getOutputFieldType(String outputField); + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index 8271a84846c..8f3e36fc30d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import 
org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -37,15 +38,19 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly public static final ParseField DEFAULT_VALUE = new ParseField("default_value"); public static final ParseField CUSTOM = new ParseField("custom"); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean lenient) { - ConstructingObjectParser parser = new ConstructingObjectParser<>( + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), lenient, - a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map)a[2], (Double)a[3], (Boolean)a[4])); + (a, c) -> new TargetMeanEncoding((String)a[0], + (String)a[1], + (Map)a[2], + (Double)a[3], + a[4] == null ? 
c.isCustomByDefault() : (Boolean)a[4])); parser.declareString(ConstructingObjectParser.constructorArg(), FIELD); parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME); parser.declareObject(ConstructingObjectParser.constructorArg(), @@ -56,12 +61,12 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly return parser; } - public static TargetMeanEncoding fromXContentStrict(XContentParser parser) { - return STRICT_PARSER.apply(parser, null); + public static TargetMeanEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) { + return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context); } - public static TargetMeanEncoding fromXContentLenient(XContentParser parser) { - return LENIENT_PARSER.apply(parser, null); + public static TargetMeanEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) { + return LENIENT_PARSER.apply(parser, context == null ? 
PreProcessorParseContext.DEFAULT : context); } private final String field; @@ -128,6 +133,11 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly return custom; } + @Override + public String getOutputFieldType(String outputField) { + return NumberFieldMapper.NumberType.DOUBLE.typeName(); + } + @Override public String getName() { return NAME.getPreferredName(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java index 7739acf1ad7..0e95b66dd15 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java @@ -41,7 +41,7 @@ public class InferenceDefinition { (p, c, n) -> p.namedObject(InferenceModel.class, n, null), TRAINED_MODEL); PARSER.declareNamedObjects(InferenceDefinition.Builder::setPreProcessors, - (p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, null), + (p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT), (trainedModelDefBuilder) -> {}, PREPROCESSORS); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 372b9b2e0d5..a1183e5fc1d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -327,12 +327,14 @@ public final class ReservedFieldNames { Regression.LOSS_FUNCTION_PARAMETER.getPreferredName(), 
Regression.PREDICTION_FIELD_NAME.getPreferredName(), Regression.TRAINING_PERCENT.getPreferredName(), + Regression.FEATURE_PROCESSORS.getPreferredName(), Classification.NAME.getPreferredName(), Classification.DEPENDENT_VARIABLE.getPreferredName(), Classification.PREDICTION_FIELD_NAME.getPreferredName(), Classification.CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), Classification.NUM_TOP_CLASSES.getPreferredName(), Classification.TRAINING_PERCENT.getPreferredName(), + Classification.FEATURE_PROCESSORS.getPreferredName(), BoostedTreeParams.LAMBDA.getPreferredName(), BoostedTreeParams.GAMMA.getPreferredName(), BoostedTreeParams.ETA.getPreferredName(), diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json index c880aa0c714..ad05797d15f 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json @@ -34,6 +34,9 @@ "feature_bag_fraction" : { "type" : "double" }, + "feature_processors": { + "enabled": false + }, "gamma" : { "type" : "double" }, @@ -84,6 +87,9 @@ "feature_bag_fraction" : { "type" : "double" }, + "feature_processors": { + "enabled": false + }, "gamma" : { "type" : "double" }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java index 83ad1f8285d..bc232dee600 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java @@ -16,6 +16,7 @@ import 
org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Respon import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.util.ArrayList; import java.util.Collections; @@ -28,6 +29,7 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractWireSerial List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } @@ -36,6 +38,7 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractWireSerial List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(namedXContent); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java index 63bc77d4711..a68d40c10b4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java @@ -19,6 +19,7 @@ import 
org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.junit.Before; import java.util.ArrayList; @@ -44,6 +45,7 @@ public class PutDataFrameAnalyticsActionRequestTests extends AbstractSerializing List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } @@ -52,6 +54,7 @@ public class PutDataFrameAnalyticsActionRequestTests extends AbstractSerializing List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(namedXContent); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java index 7f160781012..a92aceb3baa 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java @@ -13,6 +13,7 @@ import 
org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.util.ArrayList; import java.util.Collections; @@ -25,6 +26,7 @@ public class PutDataFrameAnalyticsActionResponseTests extends AbstractWireSerial List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + namedWriteables.addAll(new MlInferenceNamedXContentProvider() .getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index cc095adb225..dca53212413 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -41,6 +41,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.junit.Before; @@ -78,6 +79,7 @@ public class 
DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } @@ -86,6 +88,7 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(namedXContent); } @@ -144,14 +147,16 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC bwcRegression.getTrainingPercent(), 42L, bwcRegression.getLossFunction(), - bwcRegression.getLossFunctionParameter()); + bwcRegression.getLossFunctionParameter(), + bwcRegression.getFeatureProcessors()); testAnalysis = new Regression(testRegression.getDependentVariable(), testRegression.getBoostedTreeParams(), testRegression.getPredictionFieldName(), testRegression.getTrainingPercent(), 42L, testRegression.getLossFunction(), - testRegression.getLossFunctionParameter()); + testRegression.getLossFunctionParameter(), + bwcRegression.getFeatureProcessors()); } else { Classification testClassification = (Classification)testInstance.getAnalysis(); Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis(); @@ -161,14 +166,16 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC bwcClassification.getClassAssignmentObjective(), bwcClassification.getNumTopClasses(), 
bwcClassification.getTrainingPercent(), - 42L); + 42L, + bwcClassification.getFeatureProcessors()); testAnalysis = new Classification(testClassification.getDependentVariable(), testClassification.getBoostedTreeParams(), testClassification.getPredictionFieldName(), testClassification.getClassAssignmentObjective(), testClassification.getNumTopClasses(), testClassification.getTrainingPercent(), - 42L); + 42L, + testClassification.getFeatureProcessors()); } super.assertOnBWCObject(new DataFrameAnalyticsConfig.Builder(bwcSerializedObject) .setAnalysis(bwcAnalysis) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 426018d89c0..69f15134d1a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -8,25 +8,41 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; import 
org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; @@ -55,6 +71,21 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + public static Classification createRandom() { String 
dependentVariableName = randomAlphaOfLength(10); BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom(); @@ -65,7 +96,14 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase randomFrom(FrequencyEncodingTests.createRandom(true), + OneHotEncodingTests.createRandom(true), + TargetMeanEncodingTests.createRandom(true))) + .limit(randomIntBetween(0, 5)) + .collect(Collectors.toList())); } public static Classification mutateForVersion(Classification instance, Version version) { @@ -75,7 +113,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong(), null)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenNumTopClassesIsLessThanZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null)); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testConstructor_GivenNumTopClassesIsGreaterThan1000() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 
null, 1001, 1.0, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null)); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testGetPredictionFieldName() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null); assertThat(classification.getPredictionFieldName(), equalTo("result")); - classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null); assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction")); } public void testClassAssignmentObjective() { Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", - Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong()); + Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null); assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY)); classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", - Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong()); + Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null); assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL)); // class_assignment_objective == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null); 
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL)); } public void testGetNumTopClasses() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null); assertThat(classification.getNumTopClasses(), equalTo(7)); // Boundary condition: num_top_classes == 0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null); assertThat(classification.getNumTopClasses(), equalTo(0)); // Boundary condition: num_top_classes == 1000 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null); assertThat(classification.getNumTopClasses(), equalTo(1000)); // num_top_classes == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null); assertThat(classification.getNumTopClasses(), equalTo(2)); } public void testGetTrainingPercent() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null); assertThat(classification.getTrainingPercent(), equalTo(50.0)); // Boundary condition: training_percent == 1.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong()); + classification = new 
Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null); assertThat(classification.getTrainingPercent(), equalTo(1.0)); // Boundary condition: training_percent == 100.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null); assertThat(classification.getTrainingPercent(), equalTo(100.0)); // training_percent == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null); assertThat(classification.getTrainingPercent(), equalTo(100.0)); } @@ -233,6 +325,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + public static Regression createRandom() { return createRandom(BoostedTreeParamsTests.createRandom()); } @@ -57,7 +89,14 @@ public class RegressionTests extends AbstractBWCSerializationTestCase randomFrom(FrequencyEncodingTests.createRandom(true), + OneHotEncodingTests.createRandom(true), + TargetMeanEncodingTests.createRandom(true))) + .limit(randomIntBetween(0, 5)) + .collect(Collectors.toList())); } public static Regression mutateForVersion(Regression instance, Version version) { @@ -67,7 +106,8 @@ public class RegressionTests extends 
AbstractBWCSerializationTestCase new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null, null)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null)); + assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenLossFunctionParameterIsZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0, null)); assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double")); } public void testConstructor_GivenLossFunctionParameterIsNegative() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0, null)); assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double")); } public void testGetPredictionFieldName() { - Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 
1.0); + Regression regression = new Regression( + "foo", + BOOSTED_TREE_PARAMS, + "result", + 50.0, + randomLong(), + Regression.LossFunction.MSE, + 1.0, + null); assertThat(regression.getPredictionFieldName(), equalTo("result")); - regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null, null); assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction")); } public void testGetTrainingPercent() { - Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0); + Regression regression = new Regression("foo", + BOOSTED_TREE_PARAMS, + "result", + 50.0, + randomLong(), + Regression.LossFunction.MSE, + 1.0, + null); assertThat(regression.getTrainingPercent(), equalTo(50.0)); // Boundary condition: training_percent == 1.0 - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null, null); assertThat(regression.getTrainingPercent(), equalTo(1.0)); // Boundary condition: training_percent == 100.0 - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null, null); assertThat(regression.getTrainingPercent(), equalTo(100.0)); // training_percent == null, default applied - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null, null); 
assertThat(regression.getTrainingPercent(), equalTo(100.0)); } @@ -165,6 +273,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase params = regression.getParams(null); @@ -182,7 +291,9 @@ public class RegressionTests extends AbstractBWCSerializationTestCase params = regression.getParams(null); - int expectedParamsCount = 4 + (regression.getLossFunctionParameter() == null ? 0 : 1); + int expectedParamsCount = 4 + + (regression.getLossFunctionParameter() == null ? 0 : 1) + + (regression.getFeatureProcessors().isEmpty() ? 0 : 1); assertThat(params.size(), equalTo(expectedParamsCount)); assertThat(params.get("dependent_variable"), equalTo(regression.getDependentVariable())); assertThat(params.get("prediction_field_name"), equalTo(regression.getPredictionFieldName())); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java index 1e63aac0869..254ffa6962d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java @@ -24,7 +24,9 @@ public class FrequencyEncodingTests extends PreProcessingTests valueMap = new HashMap<>(); for (int i = 0; i < valuesSize; i++) { @@ -41,7 +47,7 @@ public class FrequencyEncodingTests extends PreProcessingTests { @Override protected OneHotEncoding doParseInstance(XContentParser parser) throws IOException { - return lenient ? OneHotEncoding.fromXContentLenient(parser) : OneHotEncoding.fromXContentStrict(parser); + return lenient ? 
+ OneHotEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) : + OneHotEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT); } @Override @@ -33,6 +35,10 @@ public class OneHotEncodingTests extends PreProcessingTests { } public static OneHotEncoding createRandom() { + return createRandom(randomBoolean() ? randomBoolean() : null); + } + + public static OneHotEncoding createRandom(Boolean isCustom) { int valuesSize = randomIntBetween(1, 10); Map valueMap = new HashMap<>(); for (int i = 0; i < valuesSize; i++) { @@ -40,7 +46,7 @@ public class OneHotEncodingTests extends PreProcessingTests { } return new OneHotEncoding(randomAlphaOfLength(10), valueMap, - randomBoolean() ? randomBoolean() : null); + isCustom); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java index 60765c83e11..9a31da55fc2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java @@ -24,7 +24,9 @@ public class TargetMeanEncodingTests extends PreProcessingTests valueMap = new HashMap<>(); for (int i = 0; i < valuesSize; i++) { @@ -42,7 +49,7 @@ public class TargetMeanEncodingTests extends PreProcessingTests entries = new ArrayList<>(searchModule.getNamedXContents()); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(entries); + } + public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { 
initialize("classification_single_numeric_feature_and_mixed_data_set"); String predictedClassField = KEYWORD_FIELD + "_prediction"; @@ -121,6 +138,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { null, null, null, + null, null)); putAnalytics(config); @@ -176,6 +194,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { null, null, null, + null, null)); putAnalytics(config); @@ -268,6 +287,76 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); } + public void testWithCustomFeatureProcessors() throws Exception { + initialize("classification_with_custom_feature_processors"); + String predictedClassField = KEYWORD_FIELD + "_prediction"; + indexData(sourceIndex, 300, 50, KEYWORD_FIELD); + + DataFrameAnalyticsConfig config = + buildAnalytics(jobId, sourceIndex, destIndex, null, + new Classification( + KEYWORD_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null, + null, + null, + Arrays.asList( + new OneHotEncoding(TEXT_FIELD, Collections.singletonMap(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom"), true) + ))); + putAnalytics(config); + + assertIsStopped(jobId); + assertProgressIsZero(jobId); + + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + client().admin().indices().refresh(new RefreshRequest(destIndex)); + SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); + for (SearchHit hit : sourceData.getHits()) { + Map destDoc = getDestDoc(config, hit); + Map resultsObject = getFieldValue(destDoc, "ml"); + assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); + assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); + assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); + 
@SuppressWarnings("unchecked") + List> importanceArray = (List>)resultsObject.get("feature_importance"); + assertThat(importanceArray, hasSize(greaterThan(0))); + } + + assertProgressComplete(jobId); + assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); + assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword"); + assertThatAuditMessagesMatch(jobId, + "Created analytics with analysis type [classification]", + "Estimated memory usage for this analytics to be", + "Starting analytics on node", + "Started analytics", + expectedDestIndexAuditMessage(), + "Started reindexing to destination index [" + destIndex + "]", + "Finished reindexing to destination index [" + destIndex + "]", + "Started loading data", + "Started analyzing", + "Started writing results", + "Finished analysis"); + assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); + + GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE, + new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet(); + assertThat(response.getResources().results().size(), equalTo(1)); + TrainedModelConfig modelConfig = response.getResources().results().get(0); + modelConfig.ensureParsedDefinition(xContentRegistry()); + assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0)); + for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) { + PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i); + assertThat(preProcessor.isCustom(), equalTo(i == 0)); + } + } + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId, String dependentVariable, List dependentVariableValues, @@ -283,7 +372,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - 
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null)); + new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null, null)); putAnalytics(config); assertIsStopped(jobId); @@ -352,7 +441,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { "classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "integer"); } - public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception { + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() { ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty( @@ -360,7 +449,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];")); } - public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsText() throws Exception { + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsText() { ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty( @@ -549,7 +638,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, - new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null)); + new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null)); putAnalytics(firstJob); String secondJobId = "classification_two_jobs_with_same_randomize_seed_2"; @@ -557,7 +646,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { long 
randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed(); DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null, - new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed)); + new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null)); putAnalytics(secondJob); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java index ded1e0491de..1e1e521652a 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java @@ -104,6 +104,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg 100.0, null, null, + null, null)) .buildForExplain(); @@ -122,6 +123,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg 50.0, null, null, + null, null)) .buildForExplain(); @@ -149,6 +151,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg 100.0, null, null, + null, null)) .buildForExplain(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 31f70e02547..73e07b8aca6 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -14,24 +14,36 
@@ import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.junit.After; import java.io.IOException; import java.time.Instant; import java.util.Arrays; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -65,6 +77,15 @@ public class 
RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { cleanUp(); } + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + List entries = new ArrayList<>(searchModule.getNamedXContents()); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(entries); + } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/59413") public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { initialize("regression_single_numeric_feature_and_mixed_data_set"); @@ -79,6 +100,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { null, null, null, + null, null) ); putAnalytics(config); @@ -192,7 +214,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null)); + new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null, null)); putAnalytics(config); assertIsStopped(jobId); @@ -319,7 +341,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null)); + new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null, null)); putAnalytics(firstJob); String secondJobId = "regression_two_jobs_with_same_randomize_seed_2"; @@ -327,7 +349,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed(); 
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null)); + new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null, null)); putAnalytics(secondJob); @@ -388,7 +410,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null)); + new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null, null)); putAnalytics(config); assertIsStopped(jobId); @@ -415,6 +437,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { null, null, null, + null, null) ); putAnalytics(config); @@ -511,6 +534,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { 90.0, null, null, + null, null); DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder() .setId(jobId) @@ -566,6 +590,73 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { "Finished analysis"); } + public void testWithCustomFeatureProcessors() throws Exception { + initialize("regression_with_custom_feature_processors"); + String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction"; + indexData(sourceIndex, 300, 50); + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Regression( + DEPENDENT_VARIABLE_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null, + null, + null, + Arrays.asList( + new OneHotEncoding(DISCRETE_NUMERICAL_FEATURE_FIELD, + Collections.singletonMap(DISCRETE_NUMERICAL_FEATURE_VALUES.get(0).toString(), "tenner"), true) + )) + ); + putAnalytics(config); + + assertIsStopped(jobId); + assertProgressIsZero(jobId); + + 
startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + // for debugging + SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); + for (SearchHit hit : sourceData.getHits()) { + Map destDoc = getDestDoc(config, hit); + Map resultsObject = getMlResultsObjectFromDestDoc(destDoc); + + assertThat(resultsObject.containsKey(predictedClassField), is(true)); + assertThat(resultsObject.containsKey("is_training"), is(true)); + assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); + } + + assertProgressComplete(jobId); + assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); + assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); + assertThatAuditMessagesMatch(jobId, + "Created analytics with analysis type [regression]", + "Estimated memory usage for this analytics to be", + "Starting analytics on node", + "Started analytics", + "Creating destination index [" + destIndex + "]", + "Started reindexing to destination index [" + destIndex + "]", + "Finished reindexing to destination index [" + destIndex + "]", + "Started loading data", + "Started analyzing", + "Started writing results", + "Finished analysis"); + GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE, + new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet(); + assertThat(response.getResources().results().size(), equalTo(1)); + TrainedModelConfig modelConfig = response.getResources().results().get(0); + modelConfig.ensureParsedDefinition(xContentRegistry()); + assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0)); + for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) { + PreProcessor preProcessor = 
modelConfig.getModelDefinition().getPreProcessors().get(i); + assertThat(preProcessor.isCustom(), equalTo(i == 0)); + } + } + private void initialize(String jobId) { this.jobId = jobId; this.sourceIndex = jobId + "_source_index"; diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index 74581ac3d45..624aee9e41b 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -71,7 +71,7 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase { analyticsConfig, new DataFrameAnalyticsAuditor(client(), "test-node"), (ex) -> { throw new ElasticsearchException(ex); }, - new ExtractedFields(extractedFieldList, Collections.emptyMap()) + new ExtractedFields(extractedFieldList, Collections.emptyList(), Collections.emptyMap()) ); //Accuracy for size is not tested here diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java index 0721d56290d..dc0d390760b 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigUpdate; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import 
org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; @@ -171,9 +172,9 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { blockingCall( actionListener -> configProvider.put(initialConfig, emptyMap(), actionListener), configHolder, exceptionHolder); + assertNoException(exceptionHolder); assertThat(configHolder.get(), is(notNullValue())); assertThat(configHolder.get(), is(equalTo(initialConfig))); - assertThat(exceptionHolder.get(), is(nullValue())); } { // Update that changes description AtomicReference updatedConfigHolder = new AtomicReference<>(); @@ -188,7 +189,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { actionListener -> configProvider.update(configUpdate, emptyMap(), ClusterState.EMPTY_STATE, actionListener), updatedConfigHolder, exceptionHolder); - + assertNoException(exceptionHolder); assertThat(updatedConfigHolder.get(), is(notNullValue())); assertThat( updatedConfigHolder.get(), @@ -196,7 +197,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { new DataFrameAnalyticsConfig.Builder(initialConfig) .setDescription("description-1") .build()))); - assertThat(exceptionHolder.get(), is(nullValue())); } { // Update that changes model memory limit AtomicReference updatedConfigHolder = new AtomicReference<>(); @@ -212,6 +212,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { updatedConfigHolder, exceptionHolder); + assertNoException(exceptionHolder); assertThat(updatedConfigHolder.get(), is(notNullValue())); assertThat( 
updatedConfigHolder.get(), @@ -220,7 +221,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { .setDescription("description-1") .setModelMemoryLimit(new ByteSizeValue(1024)) .build()))); - assertThat(exceptionHolder.get(), is(nullValue())); } { // Noop update AtomicReference updatedConfigHolder = new AtomicReference<>(); @@ -233,6 +233,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { updatedConfigHolder, exceptionHolder); + assertNoException(exceptionHolder); assertThat(updatedConfigHolder.get(), is(notNullValue())); assertThat( updatedConfigHolder.get(), @@ -241,7 +242,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { .setDescription("description-1") .setModelMemoryLimit(new ByteSizeValue(1024)) .build()))); - assertThat(exceptionHolder.get(), is(nullValue())); } { // Update that changes both description and model memory limit AtomicReference updatedConfigHolder = new AtomicReference<>(); @@ -258,6 +258,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { updatedConfigHolder, exceptionHolder); + assertNoException(exceptionHolder); assertThat(updatedConfigHolder.get(), is(notNullValue())); assertThat( updatedConfigHolder.get(), @@ -266,7 +267,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { .setDescription("description-2") .setModelMemoryLimit(new ByteSizeValue(2048)) .build()))); - assertThat(exceptionHolder.get(), is(nullValue())); } { // Update that applies security headers Map securityHeaders = Collections.singletonMap("_xpack_security_authentication", "dummy"); @@ -281,6 +281,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { updatedConfigHolder, exceptionHolder); + assertNoException(exceptionHolder); assertThat(updatedConfigHolder.get(), is(notNullValue())); assertThat( updatedConfigHolder.get(), @@ -290,7 +291,6 @@ public class DataFrameAnalyticsConfigProviderIT 
extends MlSingleNodeTestCase { .setModelMemoryLimit(new ByteSizeValue(2048)) .setHeaders(securityHeaders) .build()))); - assertThat(exceptionHolder.get(), is(nullValue())); } } @@ -371,6 +371,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase { List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(namedXContent); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java index eb13f2395ef..580f0c349a1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java @@ -28,7 +28,9 @@ public class TimeBasedExtractedFields extends ExtractedFields { private final ExtractedField timeField; public TimeBasedExtractedFields(ExtractedField timeField, List allFields) { - super(allFields, Collections.emptyMap()); + super(allFields, + Collections.emptyList(), + Collections.emptyMap()); if (!allFields.contains(timeField)) { throw new IllegalArgumentException("timeField should also be contained in allFields"); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index a82fc92a675..6872585929b 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -28,15 +28,18 @@ import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.DestinationIndex; import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.ml.extractor.ProcessedField; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; @@ -46,6 +49,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * An implementation that extracts data from elasticsearch using search and scroll on a client. 
@@ -67,10 +71,29 @@ public class DataFrameDataExtractor { private boolean hasNext; private boolean searchHasShardFailure; private final CachedSupplier trainTestSplitter; + // These are fields that are sent directly to the analytics process + // They are not passed through a feature_processor + private final String[] organicFeatures; + // These are the output field names for the feature_processors + private final String[] processedFeatures; + private final Map extractedFieldsByName; DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) { this.client = Objects.requireNonNull(client); this.context = Objects.requireNonNull(context); + Set processedFieldInputs = context.extractedFields.getProcessedFieldInputs(); + this.organicFeatures = context.extractedFields.getAllFields() + .stream() + .map(ExtractedField::getName) + .filter(f -> processedFieldInputs.contains(f) == false) + .toArray(String[]::new); + this.processedFeatures = context.extractedFields.getProcessedFields() + .stream() + .map(ProcessedField::getOutputFieldNames) + .flatMap(List::stream) + .toArray(String[]::new); + this.extractedFieldsByName = new LinkedHashMap<>(); + context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f)); hasNext = true; searchHasShardFailure = false; this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create); @@ -188,26 +211,78 @@ public class DataFrameDataExtractor { return rows; } + private String extractNonProcessedValues(SearchHit hit, String organicFeature) { + ExtractedField field = extractedFieldsByName.get(organicFeature); + Object[] values = field.value(hit); + if (values.length == 1 && isValidValue(values[0])) { + return Objects.toString(values[0]); + } + if (values.length == 0 && context.supportsRowsWithMissingValues) { + // if values is empty then it means it's a missing value + return NULL_VALUE; + } + // we are here if we have a missing value but the analysis does not support 
those + // or the value type is not supported (e.g. arrays, etc.) + return null; + } + + private String[] extractProcessedValue(ProcessedField processedField, SearchHit hit) { + Object[] values = processedField.value(hit, extractedFieldsByName::get); + if (values.length == 0 && context.supportsRowsWithMissingValues == false) { + return null; + } + final String[] extractedValue = new String[processedField.getOutputFieldNames().size()]; + for (int i = 0; i < processedField.getOutputFieldNames().size(); i++) { + extractedValue[i] = NULL_VALUE; + } + // if values is empty then it means it's a missing value + if (values.length == 0) { + return extractedValue; + } + + if (values.length != processedField.getOutputFieldNames().size()) { + throw ExceptionsHelper.badRequestException( + "field_processor [{}] output size expected to be [{}], instead it was [{}]", + processedField.getProcessorName(), + processedField.getOutputFieldNames().size(), + values.length); + } + + for (int i = 0; i < processedField.getOutputFieldNames().size(); ++i) { + Object value = values[i]; + if (value == null && context.supportsRowsWithMissingValues) { + continue; + } + if (isValidValue(value) == false) { + // we are here if we have a missing value but the analysis does not support those + // or the value type is not supported (e.g. arrays, etc.) 
+ return null; + } + extractedValue[i] = Objects.toString(value); + } + return extractedValue; + } + private Row createRow(SearchHit hit) { - String[] extractedValues = new String[context.extractedFields.getAllFields().size()]; - for (int i = 0; i < extractedValues.length; ++i) { - ExtractedField field = context.extractedFields.getAllFields().get(i); - Object[] values = field.value(hit); - if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) { - extractedValues[i] = Objects.toString(values[0]); - } else { - if (values.length == 0 && context.supportsRowsWithMissingValues) { - // if values is empty then it means it's a missing value - extractedValues[i] = NULL_VALUE; - } else { - // we are here if we have a missing value but the analysis does not support those - // or the value type is not supported (e.g. arrays, etc.) - extractedValues = null; - break; - } + String[] extractedValues = new String[organicFeatures.length + processedFeatures.length]; + int i = 0; + for (String organicFeature : organicFeatures) { + String extractedValue = extractNonProcessedValues(hit, organicFeature); + if (extractedValue == null) { + return new Row(null, hit, true); + } + extractedValues[i++] = extractedValue; + } + for (ProcessedField processedField : context.extractedFields.getProcessedFields()) { + String[] processedValues = extractProcessedValue(processedField, hit); + if (processedValues == null) { + return new Row(null, hit, true); + } + for (String processedValue : processedValues) { + extractedValues[i++] = processedValue; } } - boolean isTraining = extractedValues == null ? 
false : trainTestSplitter.get().isTraining(extractedValues); + boolean isTraining = trainTestSplitter.get().isTraining(extractedValues); return new Row(extractedValues, hit, isTraining); } @@ -241,7 +316,7 @@ public class DataFrameDataExtractor { } public List getFieldNames() { - return context.extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList()); + return Stream.concat(Arrays.stream(organicFeatures), Arrays.stream(processedFeatures)).collect(Collectors.toList()); } public ExtractedFields getExtractedFields() { @@ -253,12 +328,12 @@ public class DataFrameDataExtractor { SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder); long rows = searchResponse.getHits().getTotalHits().value; LOGGER.debug("[{}] Data summary rows [{}]", context.jobId, rows); - return new DataSummary(rows, context.extractedFields.getAllFields().size()); + return new DataSummary(rows, organicFeatures.length + processedFeatures.length); } public void collectDataSummaryAsync(ActionListener dataSummaryActionListener) { SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder(); - final int numberOfFields = context.extractedFields.getAllFields().size(); + final int numberOfFields = organicFeatures.length + processedFeatures.length; ClientHelper.executeWithHeadersAsync(context.headers, ClientHelper.ML_ORIGIN, @@ -298,7 +373,11 @@ public class DataFrameDataExtractor { } public Set getCategoricalFields(DataFrameAnalysis analysis) { - return ExtractedFieldsDetector.getCategoricalFields(context.extractedFields, analysis); + return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis); + } + + private static boolean isValidValue(Object value) { + return value instanceof Number || value instanceof String; } public static class DataSummary { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index 6378723cc0f..7a374c455de 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -13,27 +13,33 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types; import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NameResolver; import org.elasticsearch.xpack.ml.dataframe.DestinationIndex; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.ml.extractor.ProcessedField; import java.util.ArrayList; 
import java.util.Arrays; import java.util.Collections; import java.util.Comparator; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; @@ -60,7 +66,9 @@ public class ExtractedFieldsDetector { private final FieldCapabilitiesResponse fieldCapabilitiesResponse; private final Map cardinalitiesForFieldsWithConstraints; - ExtractedFieldsDetector(DataFrameAnalyticsConfig config, int docValueFieldsLimit, FieldCapabilitiesResponse fieldCapabilitiesResponse, + ExtractedFieldsDetector(DataFrameAnalyticsConfig config, + int docValueFieldsLimit, + FieldCapabilitiesResponse fieldCapabilitiesResponse, Map cardinalitiesForFieldsWithConstraints) { this.config = Objects.requireNonNull(config); this.docValueFieldsLimit = docValueFieldsLimit; @@ -69,23 +77,39 @@ public class ExtractedFieldsDetector { } public Tuple> detect() { + List processedFields = extractFeatureProcessors() + .stream() + .map(ProcessedField::new) + .collect(Collectors.toList()); TreeSet fieldSelection = new TreeSet<>(Comparator.comparing(FieldSelection::getName)); - Set fields = getIncludedFields(fieldSelection); + Set fields = getIncludedFields(fieldSelection, + processedFields.stream() + .map(ProcessedField::getInputFieldNames) + .flatMap(List::stream) + .collect(Collectors.toSet())); checkFieldsHaveCompatibleTypes(fields); checkRequiredFields(fields); checkFieldsWithCardinalityLimit(); - ExtractedFields extractedFields = detectExtractedFields(fields, fieldSelection); + ExtractedFields extractedFields = detectExtractedFields(fields, fieldSelection, processedFields); addIncludedFields(extractedFields, fieldSelection); + checkOutputFeatureUniqueness(processedFields, fields); + return Tuple.tuple(extractedFields, Collections.unmodifiableList(new ArrayList<>(fieldSelection))); } - private Set getIncludedFields(Set fieldSelection) { + private Set getIncludedFields(Set fieldSelection, Set requiredFieldsForProcessors) { Set fields = new 
TreeSet<>(fieldCapabilitiesResponse.get().keySet()); + validateFieldsRequireForProcessors(requiredFieldsForProcessors); fields.removeAll(IGNORE_FIELDS); removeFieldsUnderResultsField(fields); removeObjects(fields); applySourceFiltering(fields); + if (fields.containsAll(requiredFieldsForProcessors) == false) { + throw ExceptionsHelper.badRequestException( + "fields {} required by field_processors are not included in source filtering.", + Sets.difference(requiredFieldsForProcessors, fields)); + } FetchSourceContext analyzedFields = config.getAnalyzedFields(); // If the user has not explicitly included fields we'll include all compatible fields @@ -93,20 +117,63 @@ public class ExtractedFieldsDetector { removeFieldsWithIncompatibleTypes(fields, fieldSelection); } includeAndExcludeFields(fields, fieldSelection); + if (fields.containsAll(requiredFieldsForProcessors) == false) { + throw ExceptionsHelper.badRequestException( + "fields {} required by field_processors are not included in the analyzed_fields.", + Sets.difference(requiredFieldsForProcessors, fields)); + } return fields; } + private void validateFieldsRequireForProcessors(Set processorFields) { + Set fieldsForProcessor = new HashSet<>(processorFields); + removeFieldsUnderResultsField(fieldsForProcessor); + if (fieldsForProcessor.size() < processorFields.size()) { + throw ExceptionsHelper.badRequestException("fields contained in results field [{}] cannot be used in a feature_processor", + config.getDest().getResultsField()); + } + removeObjects(fieldsForProcessor); + if (fieldsForProcessor.size() < processorFields.size()) { + throw ExceptionsHelper.badRequestException("fields for feature_processors must not be objects"); + } + fieldsForProcessor.removeAll(IGNORE_FIELDS); + if (fieldsForProcessor.size() < processorFields.size()) { + throw ExceptionsHelper.badRequestException("the following fields cannot be used in feature_processors {}", IGNORE_FIELDS); + } + List fieldsMissingInMapping = 
processorFields.stream() + .filter(f -> fieldCapabilitiesResponse.get().containsKey(f) == false) + .collect(Collectors.toList()); + if (fieldsMissingInMapping.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "the fields {} were not found in the field capabilities of the source indices [{}]. " + + "Fields must exist and be mapped to be used in feature_processors.", + fieldsMissingInMapping, + Strings.arrayToCommaDelimitedString(config.getSource().getIndex())); + } + List processedRequiredFields = config.getAnalysis() + .getRequiredFields() + .stream() + .map(RequiredField::getName) + .filter(processorFields::contains) + .collect(Collectors.toList()); + if (processedRequiredFields.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "required analysis fields {} cannot be used in a feature_processor", + processedRequiredFields); + } + } + + /** Removes, in place, every field whose name starts with the destination results field followed by a dot. The given set must support removal. */ private void removeFieldsUnderResultsField(Set fields) { - String resultsField = config.getDest().getResultsField(); + final String resultsFieldPrefix = config.getDest().getResultsField() + "."; - Iterator fieldsIterator = fields.iterator(); - while (fieldsIterator.hasNext()) { - String field = fieldsIterator.next(); - if (field.startsWith(resultsField + ".")) { - fieldsIterator.remove(); - } - } - fields.removeIf(field -> field.startsWith(resultsField + ".")); + /* A single removeIf pass is sufficient: the explicit iterator loop that used to precede it removed exactly the same entries, which made this call a no-op. */ + fields.removeIf(field -> field.startsWith(resultsFieldPrefix)); } private void removeObjects(Set fields) { @@ -287,9 +354,23 @@ public class ExtractedFieldsDetector { } } - private ExtractedFields detectExtractedFields(Set fields, Set fieldSelection) { - ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse, - cardinalitiesForFieldsWithConstraints); + private List extractFeatureProcessors() { + if (config.getAnalysis() instanceof Classification) { + return ((Classification)config.getAnalysis()).getFeatureProcessors(); + } else if
(config.getAnalysis() instanceof Regression) { + return ((Regression)config.getAnalysis()).getFeatureProcessors(); + } + return Collections.emptyList(); + } + + private ExtractedFields detectExtractedFields(Set fields, + Set fieldSelection, + List processedFields) { + ExtractedFields extractedFields = ExtractedFields.build(fields, + Collections.emptySet(), + fieldCapabilitiesResponse, + cardinalitiesForFieldsWithConstraints, + processedFields); boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit; extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection); if (preferSource) { @@ -304,10 +385,15 @@ public class ExtractedFieldsDetector { return extractedFields; } - private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, boolean preferSource, + private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, + boolean preferSource, Set fieldSelection) { - Set requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName) + Set requiredFields = config.getAnalysis() + .getRequiredFields() + .stream() + .map(RequiredField::getName) .collect(Collectors.toSet()); + Set processorInputFields = extractedFields.getProcessedFieldInputs(); Map nameOrParentToField = new LinkedHashMap<>(); for (ExtractedField currentField : extractedFields.getAllFields()) { String nameOrParent = currentField.isMultiField() ? currentField.getParentField() : currentField.getName(); @@ -315,15 +401,37 @@ public class ExtractedFieldsDetector { if (existingField != null) { ExtractedField parent = currentField.isMultiField() ? existingField : currentField; ExtractedField multiField = currentField.isMultiField() ? currentField : existingField; + // If required fields contains parent or multifield and the processor input fields reference the other, that is an error + // we should not allow processing of data that is required. 
+ if ((requiredFields.contains(parent.getName()) && processorInputFields.contains(multiField.getName())) + || (requiredFields.contains(multiField.getName()) && processorInputFields.contains(parent.getName()))) { + throw ExceptionsHelper.badRequestException( + "feature_processors cannot be applied to required fields for analysis; multi-field [{}] parent [{}]", + multiField.getName(), + parent.getName()); + } + // If processor input fields have BOTH, we need to keep both. + if (processorInputFields.contains(parent.getName()) && processorInputFields.contains(multiField.getName())) { + throw ExceptionsHelper.badRequestException( + "feature_processors refer to both multi-field [{}] and parent [{}]. Please only refer to one or the other", + multiField.getName(), + parent.getName()); + } nameOrParentToField.put(nameOrParent, - chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection)); + chooseMultiFieldOrParent(preferSource, requiredFields, processorInputFields, parent, multiField, fieldSelection)); } } - return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints); + return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), + extractedFields.getProcessedFields(), + cardinalitiesForFieldsWithConstraints); } - private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set requiredFields, ExtractedField parent, - ExtractedField multiField, Set fieldSelection) { + private ExtractedField chooseMultiFieldOrParent(boolean preferSource, + Set requiredFields, + Set processorInputFields, + ExtractedField parent, + ExtractedField multiField, + Set fieldSelection) { // Check requirements first if (requiredFields.contains(parent.getName())) { addExcludedField(multiField.getName(), "[" + parent.getName() + "] is required instead", fieldSelection); @@ -333,6 +441,19 @@ public class ExtractedFieldsDetector { addExcludedField(parent.getName(), "[" + multiField.getName() + "] 
is required instead", fieldSelection); return multiField; } + // Choose the one required by our processors + if (processorInputFields.contains(parent.getName())) { + addExcludedField(multiField.getName(), + "[" + parent.getName() + "] is referenced by feature_processors instead", + fieldSelection); + return parent; + } + if (processorInputFields.contains(multiField.getName())) { + addExcludedField(parent.getName(), + "[" + multiField.getName() + "] is referenced by feature_processors instead", + fieldSelection); + return multiField; + } // If both are multi-fields it means there are several. In this case parent is the previous multi-field // we selected. We'll just keep that. @@ -370,7 +491,9 @@ public class ExtractedFieldsDetector { for (ExtractedField field : extractedFields.getAllFields()) { adjusted.add(field.supportsFromSource() ? field.newFromSource() : field); } - return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints); + return new ExtractedFields(adjusted, + extractedFields.getProcessedFields(), + cardinalitiesForFieldsWithConstraints); } private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) { @@ -387,13 +510,15 @@ public class ExtractedFieldsDetector { adjusted.add(field); } } - return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints); + return new ExtractedFields(adjusted, + extractedFields.getProcessedFields(), + cardinalitiesForFieldsWithConstraints); } private void addIncludedFields(ExtractedFields extractedFields, Set fieldSelection) { Set requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName) .collect(Collectors.toSet()); - Set categoricalFields = getCategoricalFields(extractedFields, config.getAnalysis()); + Set categoricalFields = getCategoricalInputFields(extractedFields, config.getAnalysis()); for (ExtractedField includedField : extractedFields.getAllFields()) { FieldSelection.FeatureType featureType = 
categoricalFields.contains(includedField.getName()) ? FieldSelection.FeatureType.CATEGORICAL : FieldSelection.FeatureType.NUMERICAL; @@ -402,7 +527,38 @@ public class ExtractedFieldsDetector { } } - static Set getCategoricalFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) { + /** Validates the output names declared by the feature processors. Throws a bad request exception when two outputs (across all processors) share a name, or when an output name equals a selected field that is not an input of any processor. */ static void checkOutputFeatureUniqueness(List processedFields, Set selectedFields) { + Set processInputs = processedFields.stream() + .map(ProcessedField::getInputFieldNames) + .flatMap(List::stream) + .collect(Collectors.toSet()); + // All analysis fields that we include that are NOT processed + // This indicates that they are sent as is + Set organicFields = Sets.difference(selectedFields, processInputs); + + Set processedFeatures = new HashSet<>(); + Set duplicatedFields = new HashSet<>(); + for (ProcessedField processedField : processedFields) { + for (String output : processedField.getOutputFieldNames()) { + if (processedFeatures.add(output) == false) { + duplicatedFields.add(output); + } + } + } + if (duplicatedFields.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "feature_processors must define unique output field names; duplicate fields {}", + duplicatedFields); + } + Set duplicateOrganicAndProcessed = Sets.intersection(organicFields, processedFeatures); + if (duplicateOrganicAndProcessed.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "feature_processors output fields must not include non-processed analysis fields; duplicate fields {}", + duplicateOrganicAndProcessed); + } + } + + static Set getCategoricalInputFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) { return extractedFields.getAllFields().stream() .filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName()) .containsAll(extractedField.getTypes())) @@ -410,6 +566,25 @@ public class ExtractedFieldsDetector { .collect(Collectors.toSet()); } + static Set getCategoricalOutputFields(ExtractedFields extractedFields,
DataFrameAnalysis analysis) { + Set processInputFields = extractedFields.getProcessedFieldInputs(); + Set categoricalFields = extractedFields.getAllFields().stream() + .filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName()) + .containsAll(extractedField.getTypes())) + .map(ExtractedField::getName) + .filter(name -> processInputFields.contains(name) == false) + .collect(Collectors.toSet()); + + extractedFields.getProcessedFields().forEach(processedField -> + processedField.getOutputFieldNames().forEach(outputField -> { + if (analysis.getAllowedCategoricalTypes(outputField).containsAll(processedField.getOutputFieldType(outputField))) { + categoricalFields.add(outputField); + } + }) + ); + return Collections.unmodifiableSet(categoricalFields); + } + private static boolean isBoolean(Set types) { return types.size() == 1 && types.contains(BooleanFieldMapper.CONTENT_TYPE); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 13358461e43..97cd053b13f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -178,7 +178,7 @@ public class AnalyticsProcessManager { AnalyticsProcess process = processContext.process.get(); AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get(); try { - writeHeaderRecord(dataExtractor, process); + writeHeaderRecord(dataExtractor, process, task); writeDataRows(dataExtractor, process, task); process.writeEndOfDataMessage(); process.flushStream(); @@ -268,8 +268,11 @@ public class AnalyticsProcessManager { } } - private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException { + private void 
writeHeaderRecord(DataFrameDataExtractor dataExtractor, + AnalyticsProcess process, + DataFrameAnalyticsTask task) throws IOException { List fieldNames = dataExtractor.getFieldNames(); + LOGGER.debug(() -> new ParameterizedMessage("[{}] header row fields {}", task.getParams().getId(), fieldNames)); // We add 2 extra fields, both named dot: // - the document hash diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java index 725627ab21c..213fa1d369f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java @@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.dataframe.process; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; @@ -22,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; @@ -34,6 +36,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import 
org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import java.time.Instant; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -191,8 +194,21 @@ public class ChunkedTrainedModelPersister { return latch; } + /** Estimates the memory used by the analysis' custom feature processors: the sum of each processor's ramBytesUsed() plus one object reference per processor. Only Classification and Regression analyses supply processors; for any other analysis the list stays empty and the result is 0. NOTE(review): the initial ArrayList is only a placeholder and is never added to - Collections.emptyList() would avoid the allocation. */ private long customProcessorSize() { + List preProcessors = new ArrayList<>(); + if (analytics.getAnalysis() instanceof Classification) { + preProcessors = ((Classification) analytics.getAnalysis()).getFeatureProcessors(); + } else if (analytics.getAnalysis() instanceof Regression) { + preProcessors = ((Regression) analytics.getAnalysis()).getFeatureProcessors(); + } + return preProcessors.stream().mapToLong(PreProcessor::ramBytesUsed).sum() + + RamUsageEstimator.NUM_BYTES_OBJECT_REF * preProcessors.size(); + } + private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) { Instant createTime = Instant.now(); + // The native process does not provide estimates for the custom feature_processor objects + long customProcessorSize = customProcessorSize(); String modelId = analytics.getId() + "-" + createTime.toEpochMilli(); currentModelId.set(modelId); List fieldNames = extractedFields.getAllFields(); @@ -214,7 +230,7 @@ public class ChunkedTrainedModelPersister { .setDescription(analytics.getDescription()) .setMetadata(Collections.singletonMap("analytics_config", XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) - .setEstimatedHeapMemory(modelSize.ramBytesUsed()) + .setEstimatedHeapMemory(modelSize.ramBytesUsed() + customProcessorSize) .setEstimatedOperations(modelSize.numOperations()) .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable)) .setLicenseLevel(License.OperationMode.PLATINUM.description()) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java index ab314a5d218..3853ea2629a 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java @@ -12,7 +12,7 @@ import org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.utils.MlStrings; -import java.util.Collection; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -21,27 +21,39 @@ import java.util.Set; import java.util.stream.Collectors; /** - * The fields the datafeed has to extract + * The fields the data[feed|frame] has to extract */ public class ExtractedFields { private final List allFields; private final List docValueFields; + private final List processedFields; private final String[] sourceFields; private final Map cardinalitiesForFieldsWithConstraints; - public ExtractedFields(List allFields, Map cardinalitiesForFieldsWithConstraints) { - this.allFields = Collections.unmodifiableList(allFields); + public ExtractedFields(List allFields, + List processedFields, + Map cardinalitiesForFieldsWithConstraints) { + this.allFields = new ArrayList<>(allFields); this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields); this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField) .toArray(String[]::new); this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints); + this.processedFields = processedFields == null ? 
Collections.emptyList() : processedFields; + } + + public List getProcessedFields() { + return processedFields; } public List getAllFields() { return allFields; } + public Set getProcessedFieldInputs() { + return processedFields.stream().map(ProcessedField::getInputFieldNames).flatMap(List::stream).collect(Collectors.toSet()); + } + public String[] getSourceFields() { return sourceFields; } @@ -58,11 +70,15 @@ public class ExtractedFields { return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList()); } - public static ExtractedFields build(Collection allFields, Set scriptFields, + public static ExtractedFields build(Set allFields, + Set scriptFields, FieldCapabilitiesResponse fieldsCapabilities, - Map cardinalitiesForFieldsWithConstraints) { + Map cardinalitiesForFieldsWithConstraints, + List processedFields) { ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities); - return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()), + return new ExtractedFields( + allFields.stream().map(extractionMethodDetector::detect).collect(Collectors.toList()), + processedFields, cardinalitiesForFieldsWithConstraints); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ProcessedField.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ProcessedField.java new file mode 100644 index 00000000000..50f13f94086 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ProcessedField.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.extractor; + +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; + +/** Wraps a non-null PreProcessor and exposes its input and output field names, its output field types, and the values it produces for a search hit. */ public class ProcessedField { + private final PreProcessor preProcessor; + + public ProcessedField(PreProcessor processor) { + this.preProcessor = Objects.requireNonNull(processor); + } + + public List getInputFieldNames() { + return preProcessor.inputFields(); + } + + public List getOutputFieldNames() { + return preProcessor.outputFields(); + } + + /** Returns a singleton set holding the type the processor reports for the given output field. */ public Set getOutputFieldType(String outputField) { + return Collections.singleton(preProcessor.getOutputFieldType(outputField)); + } + + /** Runs the wrapped processor against one hit. Each input field is resolved through fieldExtractor; if any lookup returns null, an empty array is returned immediately. An input is handed to the processor only when the hit yields exactly one value for it and that value is a String or a Number; inputs with no value, several values, or another value type are left out of the map. After process(inputs) the result is read back from the same map in outputFields() order, so an output name the map does not contain yields a null entry. NOTE(review): presumably PreProcessor#process writes its outputs into the supplied map - confirm against the PreProcessor contract. */ public Object[] value(SearchHit hit, Function fieldExtractor) { + Map inputs = new HashMap<>(preProcessor.inputFields().size(), 1.0f); + for (String field : preProcessor.inputFields()) { + ExtractedField extractedField = fieldExtractor.apply(field); + if (extractedField == null) { + return new Object[0]; + } + Object[] values = extractedField.value(hit); + if (values == null || values.length == 0) { + continue; + } + final Object value = values[0]; + if (values.length == 1 && (value instanceof String || value instanceof Number)) { + inputs.put(field, value); + } + } + preProcessor.process(inputs); + return preProcessor.outputFields().stream().map(inputs::get).toArray(); + } + + public String getProcessorName() { + return preProcessor.getName(); + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java index de2316948ff..f08ffb58757 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java +++
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java @@ -128,4 +128,11 @@ public abstract class MlSingleNodeTestCase extends ESSingleNodeTestCase { return responseHolder.get(); } + public static void assertNoException(AtomicReference error) throws Exception { + if (error.get() == null) { + return; + } + throw error.get(); + } + } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index 7280688e713..71b636d2cfe 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -15,8 +15,10 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.client.Client; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.RestStatus; @@ -27,10 +29,13 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory; 
import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.ml.extractor.ProcessedField; import org.elasticsearch.xpack.ml.extractor.SourceField; import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import org.junit.Before; @@ -45,8 +50,10 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Queue; +import java.util.function.Function; import java.util.stream.Collectors; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; @@ -83,7 +90,9 @@ public class DataFrameDataExtractorTests extends ESTestCase { query = QueryBuilders.matchAllQuery(); extractedFields = new ExtractedFields(Arrays.asList( new DocValueField("field_1", Collections.singleton("keyword")), - new DocValueField("field_2", Collections.singleton("keyword"))), Collections.emptyMap()); + new DocValueField("field_2", Collections.singleton("keyword"))), + Collections.emptyList(), + Collections.emptyMap()); scrollSize = 1000; headers = Collections.emptyMap(); @@ -304,7 +313,9 @@ public class DataFrameDataExtractorTests extends ESTestCase { // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915 extractedFields = new ExtractedFields(Arrays.asList( (ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")), - (ExtractedField) new SourceField("field_2", Collections.singleton("text"))), Collections.emptyMap()); + (ExtractedField) new SourceField("field_2", Collections.singleton("text"))), + Collections.emptyList(), + Collections.emptyMap()); TestExtractor dataExtractor = createExtractor(false, false); @@ -445,7 +456,9 @@ public class DataFrameDataExtractorTests extends ESTestCase { 
(ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")), (ExtractedField) new DocValueField("field_long", Collections.singleton("long")), (ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")), - (ExtractedField) new SourceField("field_text", Collections.singleton("text"))), Collections.emptyMap()); + (ExtractedField) new SourceField("field_text", Collections.singleton("text"))), + Collections.emptyList(), + Collections.emptyMap()); TestExtractor dataExtractor = createExtractor(true, true); assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty()); @@ -465,12 +478,100 @@ public class DataFrameDataExtractorTests extends ESTestCase { containsInAnyOrder("field_keyword", "field_text", "field_boolean")); } + public void testGetFieldNames_GivenProcessesFeatures() { + // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915 + extractedFields = new ExtractedFields(Arrays.asList( + (ExtractedField) new DocValueField("field_boolean", Collections.singleton("boolean")), + (ExtractedField) new DocValueField("field_float", Collections.singleton("float")), + (ExtractedField) new DocValueField("field_double", Collections.singleton("double")), + (ExtractedField) new DocValueField("field_byte", Collections.singleton("byte")), + (ExtractedField) new DocValueField("field_short", Collections.singleton("short")), + (ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")), + (ExtractedField) new DocValueField("field_long", Collections.singleton("long")), + (ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")), + (ExtractedField) new SourceField("field_text", Collections.singleton("text"))), + Arrays.asList( + new ProcessedField(new CategoricalPreProcessor("field_long", "animal")), + buildProcessedField("field_short", "field_1", "field_2") + ), + 
Collections.emptyMap()); + TestExtractor dataExtractor = createExtractor(true, true); + + assertThat(dataExtractor.getCategoricalFields(new Regression("field_double")), + containsInAnyOrder("field_keyword", "field_text", "animal")); + + List fieldNames = dataExtractor.getFieldNames(); + assertThat(fieldNames, containsInAnyOrder( + "animal", + "field_1", + "field_2", + "field_boolean", + "field_float", + "field_double", + "field_byte", + "field_integer", + "field_keyword", + "field_text")); + assertThat(dataExtractor.getFieldNames(), contains(fieldNames.toArray(new String[0]))); + } + + public void testExtractionWithProcessedFeatures() throws IOException { + extractedFields = new ExtractedFields(Arrays.asList( + new DocValueField("field_1", Collections.singleton("keyword")), + new DocValueField("field_2", Collections.singleton("keyword"))), + Arrays.asList( + new ProcessedField(new CategoricalPreProcessor("field_1", "animal")), + new ProcessedField(new OneHotEncoding("field_1", + Arrays.asList("11", "12") + .stream() + .collect(Collectors.toMap(Function.identity(), s -> s.equals("11") ? 
"field_11" : "field_12")), + true)) + ), + Collections.emptyMap()); + + TestExtractor dataExtractor = createExtractor(true, true); + + // First and only batch + SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3)); + dataExtractor.setNextResponse(response1); + + // Empty + SearchResponse lastAndEmptyResponse = createEmptySearchResponse(); + dataExtractor.setNextResponse(lastAndEmptyResponse); + + assertThat(dataExtractor.hasNext(), is(true)); + + // First batch + Optional> rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(true)); + assertThat(rows.get().size(), equalTo(3)); + + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"21", "dog", "1", "0"})); + assertThat(rows.get().get(1).getValues(), + equalTo(new String[] {"22", "dog", DataFrameDataExtractor.NULL_VALUE, DataFrameDataExtractor.NULL_VALUE})); + assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"23", "dog", "0", "0"})); + + assertThat(rows.get().get(0).shouldSkip(), is(false)); + assertThat(rows.get().get(1).shouldSkip(), is(false)); + assertThat(rows.get().get(2).shouldSkip(), is(false)); + } + private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) { DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory); return new TestExtractor(client, context); } + private static ProcessedField buildProcessedField(String inputField, String... outputFields) { + return new ProcessedField(buildPreProcessor(inputField, outputFields)); + } + + private static PreProcessor buildPreProcessor(String inputField, String... 
outputFields) { + return new OneHotEncoding(inputField, + Arrays.stream(outputFields).collect(Collectors.toMap((s) -> randomAlphaOfLength(10), Function.identity())), + true); + } + private SearchResponse createSearchResponse(List field1Values, List field2Values) { assertThat(field1Values.size(), equalTo(field2Values.size())); SearchResponse searchResponse = mock(SearchResponse.class); @@ -544,4 +645,70 @@ public class DataFrameDataExtractorTests extends ESTestCase { return searchResponse; } } + + private static class CategoricalPreProcessor implements PreProcessor { + + private final List inputFields; + private final List outputFields; + + CategoricalPreProcessor(String inputField, String outputField) { + this.inputFields = Arrays.asList(inputField); + this.outputFields = Arrays.asList(outputField); + } + + @Override + public List inputFields() { + return inputFields; + } + + @Override + public List outputFields() { + return outputFields; + } + + @Override + public void process(Map fields) { + fields.put(outputFields.get(0), "dog"); + } + + @Override + public Map reverseLookup() { + return null; + } + + @Override + public boolean isCustom() { + return true; + } + + @Override + public String getOutputFieldType(String outputField) { + return "text"; + } + + @Override + public long ramBytesUsed() { + return 0; + } + + @Override + public String getWriteableName() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + } + + @Override + public String getName() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return null; + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index c0b5f19803f..744452439ac 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -15,10 +15,13 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.test.SearchHitBuilder; @@ -30,11 +33,14 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import static org.hamcrest.Matchers.arrayContaining; +import static org.hamcrest.Matchers.arrayContainingInAnyOrder; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -929,12 +935,23 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { assertThat(e.getMessage(), equalTo("analyzed_fields must not include or exclude object fields: 
[object_field]")); } + private static FieldCapabilitiesResponse simpleFieldResponse() { + return new MockFieldCapsResponseBuilder() + .addAggregatableField("field_11", "float") + .addNonAggregatableField("field_21", "float") + .addAggregatableField("field_21.child", "float") + .addNonAggregatableField("field_31", "float") + .addAggregatableField("field_31.child", "float") + .addNonAggregatableField("object_field", "object") + .build(); + } + public void testDetect_GivenAnalyzedFieldExcludesObjectField() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("float_field", "float") .addNonAggregatableField("object_field", "object").build(); - analyzedFields = new FetchSourceContext(true, null, new String[] { "object_field" }); + analyzedFields = new FetchSourceContext(true, null, new String[]{"object_field"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); @@ -943,6 +960,177 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { assertThat(e.getMessage(), equalTo("analyzed_fields must not include or exclude object fields: [object_field]")); } + public void testDetect_givenFeatureProcessorsFailures_ResultsField() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("ml.result", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("fields contained in results field [ml] cannot be used in a feature_processor")); + } + + public void testDetect_givenFeatureProcessorsFailures_Objects() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + 
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("object_field", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("fields for feature_processors must not be objects")); + } + + public void testDetect_givenFeatureProcessorsFailures_ReservedFields() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("_id", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("the following fields cannot be used in feature_processors")); + } + + public void testDetect_givenFeatureProcessorsFailures_MissingFieldFromIndex() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("bar", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("the fields [bar] were not found in the field capabilities of the source indices")); + } + + public void testDetect_givenFeatureProcessorsFailures_UsingRequiredField() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_31", 
"foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("required analysis fields [field_31] cannot be used in a feature_processor")); + } + + public void testDetect_givenFeatureProcessorsFailures_BadSourceFiltering() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + sourceFiltering = new FetchSourceContext(true, null, new String[]{"field_1*"}); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_11", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("fields [field_11] required by field_processors are not included in source filtering.")); + } + + public void testDetect_givenFeatureProcessorsFailures_MissingAnalyzedField() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + analyzedFields = new FetchSourceContext(true, null, new String[]{"field_1*"}); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_11", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("fields [field_11] required by field_processors are not included in the analyzed_fields")); + } + + public void testDetect_givenFeatureProcessorsFailures_RequiredMultiFields() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + 
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_31.child", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("feature_processors cannot be applied to required fields for analysis; ")); + + extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31.child", Arrays.asList(buildPreProcessor("field_31", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("feature_processors cannot be applied to required fields for analysis; ")); + } + + public void testDetect_givenFeatureProcessorsFailures_BothMultiFields() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", + Arrays.asList( + buildPreProcessor("field_21", "foo"), + buildPreProcessor("field_21.child", "bar") + )), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("feature_processors refer to both multi-field ")); + } + + public void testDetect_givenFeatureProcessorsFailures_DuplicateOutputFields() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", + Arrays.asList( + buildPreProcessor("field_11", "foo"), + buildPreProcessor("field_21", "foo") + )), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, 
extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("feature_processors must define unique output field names; duplicate fields [foo]")); + } + + public void testDetect_withFeatureProcessors() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("field_11", "float") + .addAggregatableField("field_21", "float") + .addNonAggregatableField("field_31", "float") + .addAggregatableField("field_31.child", "float") + .addNonAggregatableField("object_field", "object") + .build(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_11", + Arrays.asList(buildPreProcessor("field_31", "foo", "bar"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ExtractedFields extracted = extractedFieldsDetector.detect().v1(); + + assertThat(extracted.getProcessedFieldInputs(), containsInAnyOrder("field_31")); + assertThat(extracted.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toSet()), + containsInAnyOrder("field_11", "field_21", "field_31")); + assertThat(extracted.getSourceFields(), arrayContainingInAnyOrder("field_31")); + assertThat(extracted.getDocValueFields().stream().map(ExtractedField::getName).collect(Collectors.toSet()), + containsInAnyOrder("field_21", "field_11")); + assertThat(extracted.getProcessedFields(), hasSize(1)); + } + private DataFrameAnalyticsConfig buildOutlierDetectionConfig() { return new DataFrameAnalyticsConfig.Builder() .setId("foo") @@ -954,13 +1142,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { } private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) { - return new DataFrameAnalyticsConfig.Builder() - .setId("foo") - .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, sourceFiltering)) - .setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD)) - .setAnalyzedFields(analyzedFields) - .setAnalysis(new 
Regression(dependentVariable)) - .build(); + return buildRegressionConfig(dependentVariable, Collections.emptyList()); } private DataFrameAnalyticsConfig buildClassificationConfig(String dependentVariable) { @@ -972,6 +1154,29 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); } + private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, List featureprocessors) { + return new DataFrameAnalyticsConfig.Builder() + .setId("foo") + .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, sourceFiltering)) + .setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD)) + .setAnalyzedFields(analyzedFields) + .setAnalysis(new Regression(dependentVariable, + BoostedTreeParams.builder().build(), + null, + null, + null, + null, + null, + featureprocessors)) + .build(); + } + + private static PreProcessor buildPreProcessor(String inputField, String... outputFields) { + return new OneHotEncoding(inputField, + Arrays.stream(outputFields).collect(Collectors.toMap((s) -> randomAlphaOfLength(10), Function.identity())), + true); + } + /** * We assert each field individually to get useful error messages in case of failure */ @@ -987,6 +1192,23 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { } } + public void testDetect_givenFeatureProcessorsFailures_DuplicateOutputFieldsWithUnProcessedField() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", + Arrays.asList( + buildPreProcessor("field_11", "field_21") + )), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString( + "feature_processors output fields must not include non-processed analysis fields; duplicate fields [field_21]")); + } + private static class 
MockFieldCapsResponseBuilder { private final Map> fieldCaps = new HashMap<>(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java index 2fa7b348f6d..7bbedd975b3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java @@ -80,6 +80,7 @@ public class InferenceRunnerTests extends ESTestCase { public void testInferTestDocs() { ExtractedFields extractedFields = new ExtractedFields( Collections.singletonList(new SourceField("key", Collections.singleton("integer"))), + Collections.emptyList(), Collections.emptyMap()); Map doc1 = new HashMap<>(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java index a4db8de032a..976b03a6a0d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java @@ -63,7 +63,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase { public void testToXContent_GivenOutlierDetection() throws IOException { ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( new DocValueField("field_1", Collections.singleton("double")), - new DocValueField("field_2", Collections.singleton("float"))), Collections.emptyMap()); + new DocValueField("field_2", Collections.singleton("float"))), + Collections.emptyList(), + Collections.emptyMap()); DataFrameAnalysis analysis = new OutlierDetection.Builder().build(); AnalyticsProcessConfig processConfig = 
createProcessConfig(analysis, extractedFields); @@ -82,7 +84,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase { ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( new DocValueField("field_1", Collections.singleton("double")), new DocValueField("field_2", Collections.singleton("float")), - new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.emptyMap()); + new DocValueField("test_dep_var", Collections.singleton("keyword"))), + Collections.emptyList(), + Collections.emptyMap()); DataFrameAnalysis analysis = new Regression("test_dep_var"); AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); @@ -103,7 +107,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase { ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( new DocValueField("field_1", Collections.singleton("double")), new DocValueField("field_2", Collections.singleton("float")), - new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.singletonMap("test_dep_var", 5L)); + new DocValueField("test_dep_var", Collections.singleton("keyword"))), + Collections.emptyList(), + Collections.singletonMap("test_dep_var", 5L)); DataFrameAnalysis analysis = new Classification("test_dep_var"); AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); @@ -126,7 +132,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase { ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( new DocValueField("field_1", Collections.singleton("double")), new DocValueField("field_2", Collections.singleton("float")), - new DocValueField("test_dep_var", Collections.singleton("integer"))), Collections.singletonMap("test_dep_var", 8L)); + new DocValueField("test_dep_var", Collections.singleton("integer"))), + Collections.emptyList(), + Collections.singletonMap("test_dep_var", 8L)); DataFrameAnalysis analysis = new Classification("test_dep_var"); 
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index 9fbf881d530..7a720dd8b5f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -105,7 +105,9 @@ public class AnalyticsProcessManagerTests extends ESTestCase { OutlierDetectionTests.createRandom()).build(); dataExtractor = mock(DataFrameDataExtractor.class); when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS)); - when(dataExtractor.getExtractedFields()).thenReturn(new ExtractedFields(Collections.emptyList(), Collections.emptyMap())); + when(dataExtractor.getExtractedFields()).thenReturn(new ExtractedFields(Collections.emptyList(), + Collections.emptyList(), + Collections.emptyMap())); dataExtractorFactory = mock(DataFrameDataExtractorFactory.class); when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor); when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 1e404360ae7..a77637bc59b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -314,6 +314,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase { 
trainedModelProvider, auditor, statsPersister, - new ExtractedFields(fieldNames, Collections.emptyMap())); + new ExtractedFields(fieldNames, Collections.emptyList(), Collections.emptyMap())); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java index ee01e297907..5c450df29b3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java @@ -144,7 +144,7 @@ public class ChunkedTrainedModelPersisterTests extends ESTestCase { analyticsConfig, auditor, (unused)->{}, - new ExtractedFields(fieldNames, Collections.emptyMap())); + new ExtractedFields(fieldNames, Collections.emptyList(), Collections.emptyMap())); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java index a51eafd1d8b..d5c27f78103 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java @@ -16,6 +16,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.TreeSet; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -31,8 +32,10 @@ public class ExtractedFieldsTests extends ESTestCase { ExtractedField scriptField2 = new ScriptField("scripted2"); ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text")); ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text")); - ExtractedFields 
extractedFields = new ExtractedFields(Arrays.asList( - docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), Collections.emptyMap()); + ExtractedFields extractedFields = new ExtractedFields( + Arrays.asList(docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), + Collections.emptyList(), + Collections.emptyMap()); assertThat(extractedFields.getAllFields().size(), equalTo(6)); assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new), @@ -53,8 +56,11 @@ public class ExtractedFieldsTests extends ESTestCase { when(fieldCapabilitiesResponse.getField("value")).thenReturn(valueCaps); when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps); - ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("time", "value", "airline", "airport"), - new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse, Collections.emptyMap()); + ExtractedFields extractedFields = ExtractedFields.build(new TreeSet<>(Arrays.asList("time", "value", "airline", "airport")), + new HashSet<>(Collections.singletonList("airport")), + fieldCapabilitiesResponse, + Collections.emptyMap(), + Collections.emptyList()); assertThat(extractedFields.getDocValueFields().size(), equalTo(2)); assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time")); @@ -76,8 +82,8 @@ public class ExtractedFieldsTests extends ESTestCase { when(fieldCapabilitiesResponse.getField("airport")).thenReturn(text); when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword); - ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("airline.text", "airport.keyword"), - Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap()); + ExtractedFields extractedFields = ExtractedFields.build(new TreeSet<>(Arrays.asList("airline.text", "airport.keyword")), + Collections.emptySet(), fieldCapabilitiesResponse, 
Collections.emptyMap(), Collections.emptyList()); assertThat(extractedFields.getDocValueFields().size(), equalTo(1)); assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword")); @@ -112,14 +118,18 @@ public class ExtractedFieldsTests extends ESTestCase { assertThat(mapped.getName(), equalTo(aBool.getName())); assertThat(mapped.getMethod(), equalTo(aBool.getMethod())); assertThat(mapped.supportsFromSource(), is(false)); - expectThrows(UnsupportedOperationException.class, () -> mapped.newFromSource()); + expectThrows(UnsupportedOperationException.class, mapped::newFromSource); } public void testBuildGivenFieldWithoutMappings() { FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build( - Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap())); + Collections.singleton("value"), + Collections.emptySet(), + fieldCapabilitiesResponse, + Collections.emptyMap(), + Collections.emptyList())); assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings")); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ProcessedFieldTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ProcessedFieldTests.java new file mode 100644 index 00000000000..48604833f08 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ProcessedFieldTests.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.extractor; + +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.ml.test.SearchHitBuilder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.arrayContaining; +import static org.hamcrest.Matchers.emptyArray; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ProcessedFieldTests extends ESTestCase { + + public void testOneHotGetters() { + String inputField = "foo"; + ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); + assertThat(processedField.getInputFieldNames(), hasItems(inputField)); + assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column")); + assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer"))); + assertThat(processedField.getOutputFieldType("baz_column"), equalTo(Collections.singleton("integer"))); + assertThat(processedField.getProcessorName(), equalTo(OneHotEncoding.NAME.getPreferredName())); + } + + public void testMissingExtractor() { + String inputField = "foo"; + ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); + assertThat(processedField.value(makeHit(), (s) -> null), emptyArray()); + } + + public void testMissingInputValues() { + String inputField = "foo"; + ExtractedField extractedField = makeExtractedField(new Object[0]); + 
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); + assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue()))); + } + + public void testProcessedField() { + ProcessedField processedField = new ProcessedField(makePreProcessor("foo", "bar", "baz")); + assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" })), arrayContaining(1, 0)); + assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "baz" })), arrayContaining(0, 1)); + } + + private static PreProcessor makePreProcessor(String inputField, String... expectedExtractedValues) { + return new OneHotEncoding(inputField, + Arrays.stream(expectedExtractedValues).collect(Collectors.toMap(Function.identity(), (s) -> s + "_column")), + true); + } + + private static ExtractedField makeExtractedField(Object[] value) { + ExtractedField extractedField = mock(ExtractedField.class); + when(extractedField.value(any())).thenReturn(value); + return extractedField; + } + + private static SearchHit makeHit() { + return new SearchHitBuilder(42).addField("a_keyword", "bar").build(); + } + +}