feature_processors allow users to create custom features from individual document fields. These `feature_processors` are the same object as the trained model's pre_processors. They are passed to the native process and the native process then appends them to the pre_processor array in the inference model. closes https://github.com/elastic/elasticsearch/issues/59327
This commit is contained in:
parent
d1b60269f4
commit
8f302282f4
|
@ -15,10 +15,14 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.mapper.FieldAliasMapper;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
@ -46,6 +50,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
|
||||
|
||||
private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
|
||||
|
||||
|
@ -59,6 +64,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
*/
|
||||
public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
|
@ -70,7 +76,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
(ClassAssignmentObjective) a[8],
|
||||
(Integer) a[9],
|
||||
(Double) a[10],
|
||||
(Long) a[11]));
|
||||
(Long) a[11],
|
||||
(List<PreProcessor>) a[12]));
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
|
@ -78,6 +85,12 @@ public class Classification implements DataFrameAnalysis {
|
|||
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
|
||||
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
|
||||
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
|
||||
parser.declareNamedObjects(optionalConstructorArg(),
|
||||
(p, c, n) -> lenient ?
|
||||
p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
|
||||
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
|
||||
(classification) -> {/*TODO should we throw if this is not set?*/},
|
||||
FEATURE_PROCESSORS);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -119,6 +132,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
private final int numTopClasses;
|
||||
private final double trainingPercent;
|
||||
private final long randomizeSeed;
|
||||
private final List<PreProcessor> featureProcessors;
|
||||
|
||||
public Classification(String dependentVariable,
|
||||
BoostedTreeParams boostedTreeParams,
|
||||
|
@ -126,7 +140,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
@Nullable ClassAssignmentObjective classAssignmentObjective,
|
||||
@Nullable Integer numTopClasses,
|
||||
@Nullable Double trainingPercent,
|
||||
@Nullable Long randomizeSeed) {
|
||||
@Nullable Long randomizeSeed,
|
||||
@Nullable List<PreProcessor> featureProcessors) {
|
||||
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
|
||||
}
|
||||
|
@ -141,10 +156,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
|
||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
|
||||
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
|
||||
}
|
||||
|
||||
public Classification(String dependentVariable) {
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
|
||||
}
|
||||
|
||||
public Classification(StreamInput in) throws IOException {
|
||||
|
@ -163,6 +179,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
} else {
|
||||
randomizeSeed = Randomness.get().nextLong();
|
||||
}
|
||||
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
|
||||
} else {
|
||||
featureProcessors = Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
public String getDependentVariable() {
|
||||
|
@ -193,6 +214,10 @@ public class Classification implements DataFrameAnalysis {
|
|||
return randomizeSeed;
|
||||
}
|
||||
|
||||
public List<PreProcessor> getFeatureProcessors() {
|
||||
return featureProcessors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -211,6 +236,9 @@ public class Classification implements DataFrameAnalysis {
|
|||
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||
out.writeOptionalLong(randomizeSeed);
|
||||
}
|
||||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeNamedWriteableList(featureProcessors);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -229,6 +257,9 @@ public class Classification implements DataFrameAnalysis {
|
|||
if (version.onOrAfter(Version.V_7_6_0)) {
|
||||
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
|
||||
}
|
||||
if (featureProcessors.isEmpty() == false) {
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -249,6 +280,10 @@ public class Classification implements DataFrameAnalysis {
|
|||
}
|
||||
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
|
||||
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
if (featureProcessors.isEmpty() == false) {
|
||||
params.put(FEATURE_PROCESSORS.getPreferredName(),
|
||||
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
|
@ -390,6 +425,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
|
||||
&& Objects.equals(numTopClasses, that.numTopClasses)
|
||||
&& Objects.equals(featureProcessors, that.featureProcessors)
|
||||
&& trainingPercent == that.trainingPercent
|
||||
&& randomizeSeed == that.randomizeSeed;
|
||||
}
|
||||
|
@ -397,7 +433,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
||||
numTopClasses, trainingPercent, randomizeSeed);
|
||||
numTopClasses, trainingPercent, randomizeSeed, featureProcessors);
|
||||
}
|
||||
|
||||
public enum ClassAssignmentObjective {
|
||||
|
|
|
@ -15,9 +15,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
@ -28,6 +32,7 @@ import java.util.Locale;
|
|||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis {
|
|||
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||
public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
|
||||
public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
|
||||
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
|
||||
|
||||
private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";
|
||||
|
||||
private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
|
@ -59,7 +66,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
(Double) a[8],
|
||||
(Long) a[9],
|
||||
(LossFunction) a[10],
|
||||
(Double) a[11]));
|
||||
(Double) a[11],
|
||||
(List<PreProcessor>) a[12]));
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
|
@ -67,6 +75,12 @@ public class Regression implements DataFrameAnalysis {
|
|||
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
|
||||
parser.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION);
|
||||
parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
|
||||
parser.declareNamedObjects(optionalConstructorArg(),
|
||||
(p, c, n) -> lenient ?
|
||||
p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
|
||||
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
|
||||
(regression) -> {/*TODO should we throw if this is not set?*/},
|
||||
FEATURE_PROCESSORS);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -90,6 +104,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
private final long randomizeSeed;
|
||||
private final LossFunction lossFunction;
|
||||
private final Double lossFunctionParameter;
|
||||
private final List<PreProcessor> featureProcessors;
|
||||
|
||||
public Regression(String dependentVariable,
|
||||
BoostedTreeParams boostedTreeParams,
|
||||
|
@ -97,7 +112,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
@Nullable Double trainingPercent,
|
||||
@Nullable Long randomizeSeed,
|
||||
@Nullable LossFunction lossFunction,
|
||||
@Nullable Double lossFunctionParameter) {
|
||||
@Nullable Double lossFunctionParameter,
|
||||
@Nullable List<PreProcessor> featureProcessors) {
|
||||
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
}
|
||||
|
@ -112,10 +128,11 @@ public class Regression implements DataFrameAnalysis {
|
|||
throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName());
|
||||
}
|
||||
this.lossFunctionParameter = lossFunctionParameter;
|
||||
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
|
||||
}
|
||||
|
||||
public Regression(String dependentVariable) {
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
|
||||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
|
@ -136,6 +153,11 @@ public class Regression implements DataFrameAnalysis {
|
|||
lossFunction = LossFunction.MSE;
|
||||
lossFunctionParameter = null;
|
||||
}
|
||||
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
|
||||
} else {
|
||||
featureProcessors = Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
public String getDependentVariable() {
|
||||
|
@ -166,6 +188,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
return lossFunctionParameter;
|
||||
}
|
||||
|
||||
public List<PreProcessor> getFeatureProcessors() {
|
||||
return featureProcessors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -184,6 +210,9 @@ public class Regression implements DataFrameAnalysis {
|
|||
out.writeEnum(lossFunction);
|
||||
out.writeOptionalDouble(lossFunctionParameter);
|
||||
}
|
||||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeNamedWriteableList(featureProcessors);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -204,6 +233,9 @@ public class Regression implements DataFrameAnalysis {
|
|||
if (lossFunctionParameter != null) {
|
||||
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
|
||||
}
|
||||
if (featureProcessors.isEmpty() == false) {
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -221,6 +253,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
if (lossFunctionParameter != null) {
|
||||
params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
|
||||
}
|
||||
if (featureProcessors.isEmpty() == false) {
|
||||
params.put(FEATURE_PROCESSORS.getPreferredName(),
|
||||
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
|
@ -304,13 +340,14 @@ public class Regression implements DataFrameAnalysis {
|
|||
&& trainingPercent == that.trainingPercent
|
||||
&& randomizeSeed == that.randomizeSeed
|
||||
&& lossFunction == that.lossFunction
|
||||
&& Objects.equals(featureProcessors, that.featureProcessors)
|
||||
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
|
||||
lossFunctionParameter);
|
||||
lossFunctionParameter, featureProcessors);
|
||||
}
|
||||
|
||||
public enum LossFunction {
|
||||
|
|
|
@ -57,23 +57,23 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
|
|||
|
||||
// PreProcessing Lenient
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME,
|
||||
OneHotEncoding::fromXContentLenient));
|
||||
(p, c) -> OneHotEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
|
||||
TargetMeanEncoding::fromXContentLenient));
|
||||
(p, c) -> TargetMeanEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME,
|
||||
FrequencyEncoding::fromXContentLenient));
|
||||
(p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
|
||||
CustomWordEmbedding::fromXContentLenient));
|
||||
(p, c) -> CustomWordEmbedding.fromXContentLenient(p)));
|
||||
|
||||
// PreProcessing Strict
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
|
||||
OneHotEncoding::fromXContentStrict));
|
||||
(p, c) -> OneHotEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
|
||||
TargetMeanEncoding::fromXContentStrict));
|
||||
(p, c) -> TargetMeanEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME,
|
||||
FrequencyEncoding::fromXContentStrict));
|
||||
(p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
|
||||
CustomWordEmbedding::fromXContentStrict));
|
||||
(p, c) -> CustomWordEmbedding.fromXContentStrict(p)));
|
||||
|
||||
// Model Lenient
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
|
||||
|
|
|
@ -56,8 +56,8 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
|
|||
TRAINED_MODEL);
|
||||
parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
|
||||
(p, c, n) -> ignoreUnknownFields ?
|
||||
p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
|
||||
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
|
||||
p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) :
|
||||
p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
|
||||
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
|
||||
PREPROCESSORS);
|
||||
return parser;
|
||||
|
|
|
@ -50,15 +50,15 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
|
|||
public static final ParseField EMBEDDING_WEIGHTS = new ParseField("embedding_weights");
|
||||
public static final ParseField EMBEDDING_QUANT_SCALES = new ParseField("embedding_quant_scales");
|
||||
|
||||
public static final ConstructingObjectParser<CustomWordEmbedding, Void> STRICT_PARSER = createParser(false);
|
||||
public static final ConstructingObjectParser<CustomWordEmbedding, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
|
||||
private static final ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static ConstructingObjectParser<CustomWordEmbedding, Void> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<CustomWordEmbedding, Void> parser = new ConstructingObjectParser<>(
|
||||
private static ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
a -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3]));
|
||||
(a, c) -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3]));
|
||||
|
||||
parser.declareField(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> {
|
||||
|
@ -123,11 +123,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
|
|||
}
|
||||
|
||||
public static CustomWordEmbedding fromXContentStrict(XContentParser parser) {
|
||||
return STRICT_PARSER.apply(parser, null);
|
||||
return STRICT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT);
|
||||
}
|
||||
|
||||
public static CustomWordEmbedding fromXContentLenient(XContentParser parser) {
|
||||
return LENIENT_PARSER.apply(parser, null);
|
||||
return LENIENT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT);
|
||||
}
|
||||
|
||||
private static final int CONCAT_LAYER_SIZE = 80;
|
||||
|
@ -256,6 +256,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
|
|||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getOutputFieldType(String outputField) {
|
||||
return "dense_vector";
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long size = SHALLOW_SIZE;
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -37,15 +38,18 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
|
||||
public static final ParseField CUSTOM = new ParseField("custom");
|
||||
|
||||
public static final ConstructingObjectParser<FrequencyEncoding, Void> STRICT_PARSER = createParser(false);
|
||||
public static final ConstructingObjectParser<FrequencyEncoding, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
|
||||
private static final ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static ConstructingObjectParser<FrequencyEncoding, Void> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<FrequencyEncoding, Void> parser = new ConstructingObjectParser<>(
|
||||
private static ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Boolean)a[3]));
|
||||
(a, c) -> new FrequencyEncoding((String)a[0],
|
||||
(String)a[1],
|
||||
(Map<String, Double>)a[2],
|
||||
a[3] == null ? c.isCustomByDefault() : (Boolean)a[3]));
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
|
@ -55,12 +59,12 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
return parser;
|
||||
}
|
||||
|
||||
public static FrequencyEncoding fromXContentStrict(XContentParser parser) {
|
||||
return STRICT_PARSER.apply(parser, null);
|
||||
public static FrequencyEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
|
||||
return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
|
||||
}
|
||||
|
||||
public static FrequencyEncoding fromXContentLenient(XContentParser parser) {
|
||||
return LENIENT_PARSER.apply(parser, null);
|
||||
public static FrequencyEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
|
||||
return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
|
||||
}
|
||||
|
||||
private final String field;
|
||||
|
@ -117,6 +121,11 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
return custom;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getOutputFieldType(String outputField) {
|
||||
return NumberFieldMapper.NumberType.DOUBLE.typeName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -36,27 +37,29 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
public static final ParseField HOT_MAP = new ParseField("hot_map");
|
||||
public static final ParseField CUSTOM = new ParseField("custom");
|
||||
|
||||
public static final ConstructingObjectParser<OneHotEncoding, Void> STRICT_PARSER = createParser(false);
|
||||
public static final ConstructingObjectParser<OneHotEncoding, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
|
||||
private static final ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static ConstructingObjectParser<OneHotEncoding, Void> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<OneHotEncoding, Void> parser = new ConstructingObjectParser<>(
|
||||
private static ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1], (Boolean)a[2]));
|
||||
(a, c) -> new OneHotEncoding((String)a[0],
|
||||
(Map<String, String>)a[1],
|
||||
a[2] == null ? c.isCustomByDefault() : (Boolean)a[2]));
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
|
||||
parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
|
||||
return parser;
|
||||
}
|
||||
|
||||
public static OneHotEncoding fromXContentStrict(XContentParser parser) {
|
||||
return STRICT_PARSER.apply(parser, null);
|
||||
public static OneHotEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
|
||||
return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
|
||||
}
|
||||
|
||||
public static OneHotEncoding fromXContentLenient(XContentParser parser) {
|
||||
return LENIENT_PARSER.apply(parser, null);
|
||||
public static OneHotEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
|
||||
return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
|
||||
}
|
||||
|
||||
private final String field;
|
||||
|
@ -103,6 +106,11 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
return custom;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getOutputFieldType(String outputField) {
|
||||
return NumberFieldMapper.NumberType.INTEGER.typeName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -124,8 +132,9 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
if (value == null) {
|
||||
return;
|
||||
}
|
||||
final String stringValue = value.toString();
|
||||
hotMap.forEach((val, col) -> {
|
||||
int encoding = value.toString().equals(val) ? 1 : 0;
|
||||
int encoding = stringValue.equals(val) ? 1 : 0;
|
||||
fields.put(col, encoding);
|
||||
});
|
||||
}
|
||||
|
|
|
@ -18,6 +18,18 @@ import java.util.Map;
|
|||
*/
|
||||
public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable {
|
||||
|
||||
class PreProcessorParseContext {
|
||||
public static final PreProcessorParseContext DEFAULT = new PreProcessorParseContext(false);
|
||||
final boolean defaultIsCustomValue;
|
||||
public PreProcessorParseContext(boolean defaultIsCustomValue) {
|
||||
this.defaultIsCustomValue = defaultIsCustomValue;
|
||||
}
|
||||
|
||||
public boolean isCustomByDefault() {
|
||||
return defaultIsCustomValue;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The expected input fields
|
||||
*/
|
||||
|
@ -48,4 +60,6 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
|
|||
*/
|
||||
boolean isCustom();
|
||||
|
||||
String getOutputFieldType(String outputField);
|
||||
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -37,15 +38,19 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
|
||||
public static final ParseField CUSTOM = new ParseField("custom");
|
||||
|
||||
public static final ConstructingObjectParser<TargetMeanEncoding, Void> STRICT_PARSER = createParser(false);
|
||||
public static final ConstructingObjectParser<TargetMeanEncoding, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
|
||||
private static final ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static ConstructingObjectParser<TargetMeanEncoding, Void> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<TargetMeanEncoding, Void> parser = new ConstructingObjectParser<>(
|
||||
private static ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3], (Boolean)a[4]));
|
||||
(a, c) -> new TargetMeanEncoding((String)a[0],
|
||||
(String)a[1],
|
||||
(Map<String, Double>)a[2],
|
||||
(Double)a[3],
|
||||
a[4] == null ? c.isCustomByDefault() : (Boolean)a[4]));
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
|
@ -56,12 +61,12 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
return parser;
|
||||
}
|
||||
|
||||
public static TargetMeanEncoding fromXContentStrict(XContentParser parser) {
|
||||
return STRICT_PARSER.apply(parser, null);
|
||||
public static TargetMeanEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
|
||||
return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
|
||||
}
|
||||
|
||||
public static TargetMeanEncoding fromXContentLenient(XContentParser parser) {
|
||||
return LENIENT_PARSER.apply(parser, null);
|
||||
public static TargetMeanEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
|
||||
return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
|
||||
}
|
||||
|
||||
private final String field;
|
||||
|
@ -128,6 +133,11 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
return custom;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getOutputFieldType(String outputField) {
|
||||
return NumberFieldMapper.NumberType.DOUBLE.typeName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
|
|
|
@ -41,7 +41,7 @@ public class InferenceDefinition {
|
|||
(p, c, n) -> p.namedObject(InferenceModel.class, n, null),
|
||||
TRAINED_MODEL);
|
||||
PARSER.declareNamedObjects(InferenceDefinition.Builder::setPreProcessors,
|
||||
(p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, null),
|
||||
(p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
|
||||
(trainedModelDefBuilder) -> {},
|
||||
PREPROCESSORS);
|
||||
}
|
||||
|
|
|
@ -327,12 +327,14 @@ public final class ReservedFieldNames {
|
|||
Regression.LOSS_FUNCTION_PARAMETER.getPreferredName(),
|
||||
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
|
||||
Regression.TRAINING_PERCENT.getPreferredName(),
|
||||
Regression.FEATURE_PROCESSORS.getPreferredName(),
|
||||
Classification.NAME.getPreferredName(),
|
||||
Classification.DEPENDENT_VARIABLE.getPreferredName(),
|
||||
Classification.PREDICTION_FIELD_NAME.getPreferredName(),
|
||||
Classification.CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(),
|
||||
Classification.NUM_TOP_CLASSES.getPreferredName(),
|
||||
Classification.TRAINING_PERCENT.getPreferredName(),
|
||||
Classification.FEATURE_PROCESSORS.getPreferredName(),
|
||||
BoostedTreeParams.LAMBDA.getPreferredName(),
|
||||
BoostedTreeParams.GAMMA.getPreferredName(),
|
||||
BoostedTreeParams.ETA.getPreferredName(),
|
||||
|
|
|
@ -34,6 +34,9 @@
|
|||
"feature_bag_fraction" : {
|
||||
"type" : "double"
|
||||
},
|
||||
"feature_processors": {
|
||||
"enabled": false
|
||||
},
|
||||
"gamma" : {
|
||||
"type" : "double"
|
||||
},
|
||||
|
@ -84,6 +87,9 @@
|
|||
"feature_bag_fraction" : {
|
||||
"type" : "double"
|
||||
},
|
||||
"feature_processors": {
|
||||
"enabled": false
|
||||
},
|
||||
"gamma" : {
|
||||
"type" : "double"
|
||||
},
|
||||
|
|
|
@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Respon
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
|
@ -28,6 +29,7 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractWireSerial
|
|||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
|
||||
namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
|
||||
namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(namedWriteables);
|
||||
}
|
||||
|
||||
|
@ -36,6 +38,7 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractWireSerial
|
|||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -44,6 +45,7 @@ public class PutDataFrameAnalyticsActionRequestTests extends AbstractSerializing
|
|||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
|
||||
namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
|
||||
namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(namedWriteables);
|
||||
}
|
||||
|
||||
|
@ -52,6 +54,7 @@ public class PutDataFrameAnalyticsActionRequestTests extends AbstractSerializing
|
|||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
|||
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Response;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
|
@ -25,6 +26,7 @@ public class PutDataFrameAnalyticsActionResponseTests extends AbstractWireSerial
|
|||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
|
||||
namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
|
||||
namedWriteables.addAll(new MlInferenceNamedXContentProvider() .getNamedWriteables());
|
||||
return new NamedWriteableRegistry(namedWriteables);
|
||||
}
|
||||
|
||||
|
|
|
@ -41,6 +41,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -78,6 +79,7 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
|
|||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
|
||||
namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
|
||||
namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(namedWriteables);
|
||||
}
|
||||
|
||||
|
@ -86,6 +88,7 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
|
|||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
|
@ -144,14 +147,16 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
|
|||
bwcRegression.getTrainingPercent(),
|
||||
42L,
|
||||
bwcRegression.getLossFunction(),
|
||||
bwcRegression.getLossFunctionParameter());
|
||||
bwcRegression.getLossFunctionParameter(),
|
||||
bwcRegression.getFeatureProcessors());
|
||||
testAnalysis = new Regression(testRegression.getDependentVariable(),
|
||||
testRegression.getBoostedTreeParams(),
|
||||
testRegression.getPredictionFieldName(),
|
||||
testRegression.getTrainingPercent(),
|
||||
42L,
|
||||
testRegression.getLossFunction(),
|
||||
testRegression.getLossFunctionParameter());
|
||||
testRegression.getLossFunctionParameter(),
|
||||
bwcRegression.getFeatureProcessors());
|
||||
} else {
|
||||
Classification testClassification = (Classification)testInstance.getAnalysis();
|
||||
Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis();
|
||||
|
@ -161,14 +166,16 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
|
|||
bwcClassification.getClassAssignmentObjective(),
|
||||
bwcClassification.getNumTopClasses(),
|
||||
bwcClassification.getTrainingPercent(),
|
||||
42L);
|
||||
42L,
|
||||
bwcClassification.getFeatureProcessors());
|
||||
testAnalysis = new Classification(testClassification.getDependentVariable(),
|
||||
testClassification.getBoostedTreeParams(),
|
||||
testClassification.getPredictionFieldName(),
|
||||
testClassification.getClassAssignmentObjective(),
|
||||
testClassification.getNumTopClasses(),
|
||||
testClassification.getTrainingPercent(),
|
||||
42L);
|
||||
42L,
|
||||
testClassification.getFeatureProcessors());
|
||||
}
|
||||
super.assertOnBWCObject(new DataFrameAnalyticsConfig.Builder(bwcSerializedObject)
|
||||
.setAnalysis(bwcAnalysis)
|
||||
|
|
|
@ -8,25 +8,41 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
|||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.DeprecationHandler;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.index.mapper.BooleanFieldMapper;
|
||||
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
|
@ -55,6 +71,21 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
|
||||
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(entries);
|
||||
}
|
||||
|
||||
public static Classification createRandom() {
|
||||
String dependentVariableName = randomAlphaOfLength(10);
|
||||
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
||||
|
@ -65,7 +96,14 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
||||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
||||
numTopClasses, trainingPercent, randomizeSeed);
|
||||
numTopClasses, trainingPercent, randomizeSeed,
|
||||
randomBoolean() ?
|
||||
null :
|
||||
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(true),
|
||||
OneHotEncodingTests.createRandom(true),
|
||||
TargetMeanEncodingTests.createRandom(true)))
|
||||
.limit(randomIntBetween(0, 5))
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
public static Classification mutateForVersion(Classification instance, Version version) {
|
||||
|
@ -75,7 +113,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
version.onOrAfter(Version.V_7_7_0) ? instance.getClassAssignmentObjective() : null,
|
||||
instance.getNumTopClasses(),
|
||||
instance.getTrainingPercent(),
|
||||
instance.getRandomizeSeed());
|
||||
instance.getRandomizeSeed(),
|
||||
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -91,14 +130,16 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
bwcSerializedObject.getClassAssignmentObjective(),
|
||||
bwcSerializedObject.getNumTopClasses(),
|
||||
bwcSerializedObject.getTrainingPercent(),
|
||||
42L);
|
||||
42L,
|
||||
bwcSerializedObject.getFeatureProcessors());
|
||||
Classification newInstance = new Classification(testInstance.getDependentVariable(),
|
||||
testInstance.getBoostedTreeParams(),
|
||||
testInstance.getPredictionFieldName(),
|
||||
testInstance.getClassAssignmentObjective(),
|
||||
testInstance.getNumTopClasses(),
|
||||
testInstance.getTrainingPercent(),
|
||||
42L);
|
||||
42L,
|
||||
testInstance.getFeatureProcessors());
|
||||
super.assertOnBWCObject(newBwc, newInstance, version);
|
||||
}
|
||||
|
||||
|
@ -107,87 +148,138 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
return Classification::new;
|
||||
}
|
||||
|
||||
public void testDeserialization() throws IOException {
|
||||
String toDeserialize = "{\n" +
|
||||
" \"dependent_variable\": \"FlightDelayMin\",\n" +
|
||||
" \"feature_processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"one_hot_encoding\": {\n" +
|
||||
" \"field\": \"OriginWeather\",\n" +
|
||||
" \"hot_map\": {\n" +
|
||||
" \"sunny_col\": \"Sunny\",\n" +
|
||||
" \"clear_col\": \"Clear\",\n" +
|
||||
" \"rainy_col\": \"Rain\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"one_hot_encoding\": {\n" +
|
||||
" \"field\": \"DestWeather\",\n" +
|
||||
" \"hot_map\": {\n" +
|
||||
" \"dest_sunny_col\": \"Sunny\",\n" +
|
||||
" \"dest_clear_col\": \"Clear\",\n" +
|
||||
" \"dest_rainy_col\": \"Rain\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"frequency_encoding\": {\n" +
|
||||
" \"field\": \"OriginWeather\",\n" +
|
||||
" \"feature_name\": \"mean\",\n" +
|
||||
" \"frequency_map\": {\n" +
|
||||
" \"Sunny\": 0.8,\n" +
|
||||
" \"Rain\": 0.2\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
" }" +
|
||||
"";
|
||||
|
||||
try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
||||
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
|
||||
new BytesArray(toDeserialize),
|
||||
XContentType.JSON)) {
|
||||
Classification parsed = Classification.fromXContent(parser, false);
|
||||
assertThat(parsed.getDependentVariable(), equalTo("FlightDelayMin"));
|
||||
for (PreProcessor preProcessor : parsed.getFeatureProcessors()) {
|
||||
assertThat(preProcessor.isCustom(), is(true));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong()));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong(), null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong()));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong()));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong()));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||
}
|
||||
|
||||
public void testGetPredictionFieldName() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
|
||||
assertThat(classification.getPredictionFieldName(), equalTo("result"));
|
||||
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong());
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null);
|
||||
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
|
||||
}
|
||||
|
||||
public void testClassAssignmentObjective() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
||||
Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong());
|
||||
Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null);
|
||||
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
|
||||
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
||||
Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong());
|
||||
Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null);
|
||||
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
||||
|
||||
// class_assignment_objective == null, default applied
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
||||
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
||||
}
|
||||
|
||||
public void testGetNumTopClasses() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(7));
|
||||
|
||||
// Boundary condition: num_top_classes == 0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong());
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(0));
|
||||
|
||||
// Boundary condition: num_top_classes == 1000
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong());
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(1000));
|
||||
|
||||
// num_top_classes == null, default applied
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong());
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(2));
|
||||
}
|
||||
|
||||
public void testGetTrainingPercent() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(50.0));
|
||||
|
||||
// Boundary condition: training_percent == 1.0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong());
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
||||
|
||||
// Boundary condition: training_percent == 100.0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong());
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
|
||||
// training_percent == null, default applied
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong());
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
|
@ -233,6 +325,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
null,
|
||||
null,
|
||||
50.0,
|
||||
null,
|
||||
null).getParams(fieldInfo),
|
||||
equalTo(
|
||||
org.elasticsearch.common.collect.Map.of(
|
||||
|
|
|
@ -8,18 +8,35 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
|||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.DeprecationHandler;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
|
@ -45,6 +62,21 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
|
||||
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(entries);
|
||||
}
|
||||
|
||||
public static Regression createRandom() {
|
||||
return createRandom(BoostedTreeParamsTests.createRandom());
|
||||
}
|
||||
|
@ -57,7 +89,14 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
|
||||
Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
|
||||
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
|
||||
lossFunctionParameter);
|
||||
lossFunctionParameter,
|
||||
randomBoolean() ?
|
||||
null :
|
||||
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(true),
|
||||
OneHotEncodingTests.createRandom(true),
|
||||
TargetMeanEncodingTests.createRandom(true)))
|
||||
.limit(randomIntBetween(0, 5))
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
public static Regression mutateForVersion(Regression instance, Version version) {
|
||||
|
@ -67,7 +106,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
instance.getTrainingPercent(),
|
||||
instance.getRandomizeSeed(),
|
||||
version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunction() : null,
|
||||
version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunctionParameter() : null);
|
||||
version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunctionParameter() : null,
|
||||
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -83,14 +123,16 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
bwcSerializedObject.getTrainingPercent(),
|
||||
42L,
|
||||
bwcSerializedObject.getLossFunction(),
|
||||
bwcSerializedObject.getLossFunctionParameter());
|
||||
bwcSerializedObject.getLossFunctionParameter(),
|
||||
bwcSerializedObject.getFeatureProcessors());
|
||||
Regression newInstance = new Regression(testInstance.getDependentVariable(),
|
||||
testInstance.getBoostedTreeParams(),
|
||||
testInstance.getPredictionFieldName(),
|
||||
testInstance.getTrainingPercent(),
|
||||
42L,
|
||||
testInstance.getLossFunction(),
|
||||
testInstance.getLossFunctionParameter());
|
||||
testInstance.getLossFunctionParameter(),
|
||||
testInstance.getFeatureProcessors());
|
||||
super.assertOnBWCObject(newBwc, newInstance, version);
|
||||
}
|
||||
|
||||
|
@ -104,56 +146,122 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
return Regression::new;
|
||||
}
|
||||
|
||||
public void testDeserialization() throws IOException {
|
||||
String toDeserialize = "{\n" +
|
||||
" \"dependent_variable\": \"FlightDelayMin\",\n" +
|
||||
" \"feature_processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"one_hot_encoding\": {\n" +
|
||||
" \"field\": \"OriginWeather\",\n" +
|
||||
" \"hot_map\": {\n" +
|
||||
" \"sunny_col\": \"Sunny\",\n" +
|
||||
" \"clear_col\": \"Clear\",\n" +
|
||||
" \"rainy_col\": \"Rain\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"one_hot_encoding\": {\n" +
|
||||
" \"field\": \"DestWeather\",\n" +
|
||||
" \"hot_map\": {\n" +
|
||||
" \"dest_sunny_col\": \"Sunny\",\n" +
|
||||
" \"dest_clear_col\": \"Clear\",\n" +
|
||||
" \"dest_rainy_col\": \"Rain\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"frequency_encoding\": {\n" +
|
||||
" \"field\": \"OriginWeather\",\n" +
|
||||
" \"feature_name\": \"mean\",\n" +
|
||||
" \"frequency_map\": {\n" +
|
||||
" \"Sunny\": 0.8,\n" +
|
||||
" \"Rain\": 0.2\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
" }" +
|
||||
"";
|
||||
|
||||
try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
||||
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
|
||||
new BytesArray(toDeserialize),
|
||||
XContentType.JSON)) {
|
||||
Regression parsed = Regression.fromXContent(parser, false);
|
||||
assertThat(parsed.getDependentVariable(), equalTo("FlightDelayMin"));
|
||||
for (PreProcessor preProcessor : parsed.getFeatureProcessors()) {
|
||||
assertThat(preProcessor.isCustom(), is(true));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null));
|
||||
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenLossFunctionParameterIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenLossFunctionParameterIsNegative() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
|
||||
}
|
||||
|
||||
public void testGetPredictionFieldName() {
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0);
|
||||
Regression regression = new Regression(
|
||||
"foo",
|
||||
BOOSTED_TREE_PARAMS,
|
||||
"result",
|
||||
50.0,
|
||||
randomLong(),
|
||||
Regression.LossFunction.MSE,
|
||||
1.0,
|
||||
null);
|
||||
assertThat(regression.getPredictionFieldName(), equalTo("result"));
|
||||
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null, null);
|
||||
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
|
||||
}
|
||||
|
||||
public void testGetTrainingPercent() {
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0);
|
||||
Regression regression = new Regression("foo",
|
||||
BOOSTED_TREE_PARAMS,
|
||||
"result",
|
||||
50.0,
|
||||
randomLong(),
|
||||
Regression.LossFunction.MSE,
|
||||
1.0,
|
||||
null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(50.0));
|
||||
|
||||
// Boundary condition: training_percent == 1.0
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null, null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(1.0));
|
||||
|
||||
// Boundary condition: training_percent == 100.0
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null, null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
|
||||
// training_percent == null, default applied
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null, null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
|
@ -165,6 +273,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
100.0,
|
||||
0L,
|
||||
Regression.LossFunction.MSE,
|
||||
null,
|
||||
null);
|
||||
|
||||
Map<String, Object> params = regression.getParams(null);
|
||||
|
@ -182,7 +291,9 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
|
||||
Map<String, Object> params = regression.getParams(null);
|
||||
|
||||
int expectedParamsCount = 4 + (regression.getLossFunctionParameter() == null ? 0 : 1);
|
||||
int expectedParamsCount = 4
|
||||
+ (regression.getLossFunctionParameter() == null ? 0 : 1)
|
||||
+ (regression.getFeatureProcessors().isEmpty() ? 0 : 1);
|
||||
assertThat(params.size(), equalTo(expectedParamsCount));
|
||||
assertThat(params.get("dependent_variable"), equalTo(regression.getDependentVariable()));
|
||||
assertThat(params.get("prediction_field_name"), equalTo(regression.getPredictionFieldName()));
|
||||
|
|
|
@ -24,7 +24,9 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
|
|||
|
||||
@Override
|
||||
protected FrequencyEncoding doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? FrequencyEncoding.fromXContentLenient(parser) : FrequencyEncoding.fromXContentStrict(parser);
|
||||
return lenient ?
|
||||
FrequencyEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
|
||||
FrequencyEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -33,6 +35,10 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
|
|||
}
|
||||
|
||||
public static FrequencyEncoding createRandom() {
|
||||
return createRandom(randomBoolean() ? null : randomBoolean());
|
||||
}
|
||||
|
||||
public static FrequencyEncoding createRandom(Boolean isCustom) {
|
||||
int valuesSize = randomIntBetween(1, 10);
|
||||
Map<String, Double> valueMap = new HashMap<>();
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
|
@ -41,7 +47,7 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
|
|||
return new FrequencyEncoding(randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
valueMap,
|
||||
randomBoolean() ? null : randomBoolean());
|
||||
isCustom);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -24,7 +24,9 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
|
|||
|
||||
@Override
|
||||
protected OneHotEncoding doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? OneHotEncoding.fromXContentLenient(parser) : OneHotEncoding.fromXContentStrict(parser);
|
||||
return lenient ?
|
||||
OneHotEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
|
||||
OneHotEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -33,6 +35,10 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
|
|||
}
|
||||
|
||||
public static OneHotEncoding createRandom() {
|
||||
return createRandom(randomBoolean() ? randomBoolean() : null);
|
||||
}
|
||||
|
||||
public static OneHotEncoding createRandom(Boolean isCustom) {
|
||||
int valuesSize = randomIntBetween(1, 10);
|
||||
Map<String, String> valueMap = new HashMap<>();
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
|
@ -40,7 +46,7 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
|
|||
}
|
||||
return new OneHotEncoding(randomAlphaOfLength(10),
|
||||
valueMap,
|
||||
randomBoolean() ? randomBoolean() : null);
|
||||
isCustom);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -24,7 +24,9 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
|
|||
|
||||
@Override
|
||||
protected TargetMeanEncoding doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? TargetMeanEncoding.fromXContentLenient(parser) : TargetMeanEncoding.fromXContentStrict(parser);
|
||||
return lenient ?
|
||||
TargetMeanEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
|
||||
TargetMeanEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -32,7 +34,12 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
|
|||
return createRandom();
|
||||
}
|
||||
|
||||
|
||||
public static TargetMeanEncoding createRandom() {
|
||||
return createRandom(randomBoolean() ? randomBoolean() : null);
|
||||
}
|
||||
|
||||
public static TargetMeanEncoding createRandom(Boolean isCustom) {
|
||||
int valuesSize = randomIntBetween(1, 10);
|
||||
Map<String, Double> valueMap = new HashMap<>();
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
|
@ -42,7 +49,7 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
|
|||
randomAlphaOfLength(10),
|
||||
valueMap,
|
||||
randomDoubleBetween(0.0, 1.0, false),
|
||||
randomBoolean() ? randomBoolean() : null);
|
||||
isCustom);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -21,22 +21,30 @@ import org.elasticsearch.action.support.WriteRequest;
|
|||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -108,6 +116,15 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
|
||||
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(searchModule.getNamedXContents());
|
||||
entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
||||
return new NamedXContentRegistry(entries);
|
||||
}
|
||||
|
||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||
initialize("classification_single_numeric_feature_and_mixed_data_set");
|
||||
String predictedClassField = KEYWORD_FIELD + "_prediction";
|
||||
|
@ -121,6 +138,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null));
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -176,6 +194,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null));
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -268,6 +287,76 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||
}
|
||||
|
||||
public void testWithCustomFeatureProcessors() throws Exception {
|
||||
initialize("classification_with_custom_feature_processors");
|
||||
String predictedClassField = KEYWORD_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
|
||||
|
||||
DataFrameAnalyticsConfig config =
|
||||
buildAnalytics(jobId, sourceIndex, destIndex, null,
|
||||
new Classification(
|
||||
KEYWORD_FIELD,
|
||||
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
Arrays.asList(
|
||||
new OneHotEncoding(TEXT_FIELD, Collections.singletonMap(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom"), true)
|
||||
)));
|
||||
putAnalytics(config);
|
||||
|
||||
assertIsStopped(jobId);
|
||||
assertProgressIsZero(jobId);
|
||||
|
||||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
client().admin().indices().refresh(new RefreshRequest(destIndex));
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
|
||||
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
|
||||
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
|
||||
assertThat(importanceArray, hasSize(greaterThan(0)));
|
||||
}
|
||||
|
||||
assertProgressComplete(jobId);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
"Starting analytics on node",
|
||||
"Started analytics",
|
||||
expectedDestIndexAuditMessage(),
|
||||
"Started reindexing to destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Started loading data",
|
||||
"Started analyzing",
|
||||
"Started writing results",
|
||||
"Finished analysis");
|
||||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||
|
||||
GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE,
|
||||
new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet();
|
||||
assertThat(response.getResources().results().size(), equalTo(1));
|
||||
TrainedModelConfig modelConfig = response.getResources().results().get(0);
|
||||
modelConfig.ensureParsedDefinition(xContentRegistry());
|
||||
assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0));
|
||||
for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
|
||||
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
|
||||
assertThat(preProcessor.isCustom(), equalTo(i == 0));
|
||||
}
|
||||
}
|
||||
|
||||
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId,
|
||||
String dependentVariable,
|
||||
List<T> dependentVariableValues,
|
||||
|
@ -283,7 +372,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null));
|
||||
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null, null));
|
||||
putAnalytics(config);
|
||||
|
||||
assertIsStopped(jobId);
|
||||
|
@ -352,7 +441,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "integer");
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception {
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() {
|
||||
ElasticsearchStatusException e = expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
|
@ -360,7 +449,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];"));
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsText() throws Exception {
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsText() {
|
||||
ElasticsearchStatusException e = expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
|
@ -549,7 +638,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
.build();
|
||||
|
||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null));
|
||||
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null));
|
||||
putAnalytics(firstJob);
|
||||
|
||||
String secondJobId = "classification_two_jobs_with_same_randomize_seed_2";
|
||||
|
@ -557,7 +646,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
|
||||
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
|
||||
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed));
|
||||
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null));
|
||||
|
||||
putAnalytics(secondJob);
|
||||
|
||||
|
|
|
@ -104,6 +104,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
|
|||
100.0,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null))
|
||||
.buildForExplain();
|
||||
|
||||
|
@ -122,6 +123,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
|
|||
50.0,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null))
|
||||
.buildForExplain();
|
||||
|
||||
|
@ -149,6 +151,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
|
|||
100.0,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null))
|
||||
.buildForExplain();
|
||||
|
||||
|
|
|
@ -14,24 +14,36 @@ import org.elasticsearch.action.get.GetResponse;
|
|||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.junit.After;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
@ -65,6 +77,15 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
cleanUp();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
|
||||
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(searchModule.getNamedXContents());
|
||||
entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
||||
return new NamedXContentRegistry(entries);
|
||||
}
|
||||
|
||||
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/59413")
|
||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
||||
|
@ -79,6 +100,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null)
|
||||
);
|
||||
putAnalytics(config);
|
||||
|
@ -192,7 +214,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null));
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null, null));
|
||||
putAnalytics(config);
|
||||
|
||||
assertIsStopped(jobId);
|
||||
|
@ -319,7 +341,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
.build();
|
||||
|
||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null));
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null, null));
|
||||
putAnalytics(firstJob);
|
||||
|
||||
String secondJobId = "regression_two_jobs_with_same_randomize_seed_2";
|
||||
|
@ -327,7 +349,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed();
|
||||
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null));
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null, null));
|
||||
|
||||
putAnalytics(secondJob);
|
||||
|
||||
|
@ -388,7 +410,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null));
|
||||
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null, null));
|
||||
putAnalytics(config);
|
||||
|
||||
assertIsStopped(jobId);
|
||||
|
@ -415,6 +437,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null)
|
||||
);
|
||||
putAnalytics(config);
|
||||
|
@ -511,6 +534,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
90.0,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null);
|
||||
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
|
||||
.setId(jobId)
|
||||
|
@ -566,6 +590,73 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
"Finished analysis");
|
||||
}
|
||||
|
||||
public void testWithCustomFeatureProcessors() throws Exception {
|
||||
initialize("regression_with_custom_feature_processors");
|
||||
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 300, 50);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
|
||||
new Regression(
|
||||
DEPENDENT_VARIABLE_FIELD,
|
||||
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
Arrays.asList(
|
||||
new OneHotEncoding(DISCRETE_NUMERICAL_FEATURE_FIELD,
|
||||
Collections.singletonMap(DISCRETE_NUMERICAL_FEATURE_VALUES.get(0).toString(), "tenner"), true)
|
||||
))
|
||||
);
|
||||
putAnalytics(config);
|
||||
|
||||
assertIsStopped(jobId);
|
||||
assertProgressIsZero(jobId);
|
||||
|
||||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
// for debugging
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
|
||||
|
||||
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
||||
}
|
||||
|
||||
assertProgressComplete(jobId);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [regression]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
"Starting analytics on node",
|
||||
"Started analytics",
|
||||
"Creating destination index [" + destIndex + "]",
|
||||
"Started reindexing to destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Started loading data",
|
||||
"Started analyzing",
|
||||
"Started writing results",
|
||||
"Finished analysis");
|
||||
GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE,
|
||||
new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet();
|
||||
assertThat(response.getResources().results().size(), equalTo(1));
|
||||
TrainedModelConfig modelConfig = response.getResources().results().get(0);
|
||||
modelConfig.ensureParsedDefinition(xContentRegistry());
|
||||
assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0));
|
||||
for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
|
||||
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
|
||||
assertThat(preProcessor.isCustom(), equalTo(i == 0));
|
||||
}
|
||||
}
|
||||
|
||||
private void initialize(String jobId) {
|
||||
this.jobId = jobId;
|
||||
this.sourceIndex = jobId + "_source_index";
|
||||
|
|
|
@ -71,7 +71,7 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
|
|||
analyticsConfig,
|
||||
new DataFrameAnalyticsAuditor(client(), "test-node"),
|
||||
(ex) -> { throw new ElasticsearchException(ex); },
|
||||
new ExtractedFields(extractedFieldList, Collections.emptyMap())
|
||||
new ExtractedFields(extractedFieldList, Collections.emptyList(), Collections.emptyMap())
|
||||
);
|
||||
|
||||
//Accuracy for size is not tested here
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigUpdate;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
|
||||
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
|
||||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||
|
@ -171,9 +172,9 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
blockingCall(
|
||||
actionListener -> configProvider.put(initialConfig, emptyMap(), actionListener), configHolder, exceptionHolder);
|
||||
|
||||
assertNoException(exceptionHolder);
|
||||
assertThat(configHolder.get(), is(notNullValue()));
|
||||
assertThat(configHolder.get(), is(equalTo(initialConfig)));
|
||||
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||
}
|
||||
{ // Update that changes description
|
||||
AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();
|
||||
|
@ -188,7 +189,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
actionListener -> configProvider.update(configUpdate, emptyMap(), ClusterState.EMPTY_STATE, actionListener),
|
||||
updatedConfigHolder,
|
||||
exceptionHolder);
|
||||
|
||||
assertNoException(exceptionHolder);
|
||||
assertThat(updatedConfigHolder.get(), is(notNullValue()));
|
||||
assertThat(
|
||||
updatedConfigHolder.get(),
|
||||
|
@ -196,7 +197,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
new DataFrameAnalyticsConfig.Builder(initialConfig)
|
||||
.setDescription("description-1")
|
||||
.build())));
|
||||
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||
}
|
||||
{ // Update that changes model memory limit
|
||||
AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();
|
||||
|
@ -212,6 +212,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
updatedConfigHolder,
|
||||
exceptionHolder);
|
||||
|
||||
assertNoException(exceptionHolder);
|
||||
assertThat(updatedConfigHolder.get(), is(notNullValue()));
|
||||
assertThat(
|
||||
updatedConfigHolder.get(),
|
||||
|
@ -220,7 +221,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
.setDescription("description-1")
|
||||
.setModelMemoryLimit(new ByteSizeValue(1024))
|
||||
.build())));
|
||||
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||
}
|
||||
{ // Noop update
|
||||
AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();
|
||||
|
@ -233,6 +233,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
updatedConfigHolder,
|
||||
exceptionHolder);
|
||||
|
||||
assertNoException(exceptionHolder);
|
||||
assertThat(updatedConfigHolder.get(), is(notNullValue()));
|
||||
assertThat(
|
||||
updatedConfigHolder.get(),
|
||||
|
@ -241,7 +242,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
.setDescription("description-1")
|
||||
.setModelMemoryLimit(new ByteSizeValue(1024))
|
||||
.build())));
|
||||
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||
}
|
||||
{ // Update that changes both description and model memory limit
|
||||
AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();
|
||||
|
@ -258,6 +258,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
updatedConfigHolder,
|
||||
exceptionHolder);
|
||||
|
||||
assertNoException(exceptionHolder);
|
||||
assertThat(updatedConfigHolder.get(), is(notNullValue()));
|
||||
assertThat(
|
||||
updatedConfigHolder.get(),
|
||||
|
@ -266,7 +267,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
.setDescription("description-2")
|
||||
.setModelMemoryLimit(new ByteSizeValue(2048))
|
||||
.build())));
|
||||
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||
}
|
||||
{ // Update that applies security headers
|
||||
Map<String, String> securityHeaders = Collections.singletonMap("_xpack_security_authentication", "dummy");
|
||||
|
@ -281,6 +281,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
updatedConfigHolder,
|
||||
exceptionHolder);
|
||||
|
||||
assertNoException(exceptionHolder);
|
||||
assertThat(updatedConfigHolder.get(), is(notNullValue()));
|
||||
assertThat(
|
||||
updatedConfigHolder.get(),
|
||||
|
@ -290,7 +291,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
.setModelMemoryLimit(new ByteSizeValue(2048))
|
||||
.setHeaders(securityHeaders)
|
||||
.build())));
|
||||
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -371,6 +371,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
|
|||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,7 +28,9 @@ public class TimeBasedExtractedFields extends ExtractedFields {
|
|||
private final ExtractedField timeField;
|
||||
|
||||
public TimeBasedExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) {
|
||||
super(allFields, Collections.emptyMap());
|
||||
super(allFields,
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap());
|
||||
if (!allFields.contains(timeField)) {
|
||||
throw new IllegalArgumentException("timeField should also be contained in allFields");
|
||||
}
|
||||
|
|
|
@ -28,15 +28,18 @@ import org.elasticsearch.search.fetch.StoredFieldsContext;
|
|||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
|
||||
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
import org.elasticsearch.xpack.ml.extractor.ProcessedField;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.NoSuchElementException;
|
||||
|
@ -46,6 +49,7 @@ import java.util.Set;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* An implementation that extracts data from elasticsearch using search and scroll on a client.
|
||||
|
@ -67,10 +71,29 @@ public class DataFrameDataExtractor {
|
|||
private boolean hasNext;
|
||||
private boolean searchHasShardFailure;
|
||||
private final CachedSupplier<TrainTestSplitter> trainTestSplitter;
|
||||
// These are fields that are sent directly to the analytics process
|
||||
// They are not passed through a feature_processor
|
||||
private final String[] organicFeatures;
|
||||
// These are the output field names for the feature_processors
|
||||
private final String[] processedFeatures;
|
||||
private final Map<String, ExtractedField> extractedFieldsByName;
|
||||
|
||||
DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) {
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.context = Objects.requireNonNull(context);
|
||||
Set<String> processedFieldInputs = context.extractedFields.getProcessedFieldInputs();
|
||||
this.organicFeatures = context.extractedFields.getAllFields()
|
||||
.stream()
|
||||
.map(ExtractedField::getName)
|
||||
.filter(f -> processedFieldInputs.contains(f) == false)
|
||||
.toArray(String[]::new);
|
||||
this.processedFeatures = context.extractedFields.getProcessedFields()
|
||||
.stream()
|
||||
.map(ProcessedField::getOutputFieldNames)
|
||||
.flatMap(List::stream)
|
||||
.toArray(String[]::new);
|
||||
this.extractedFieldsByName = new LinkedHashMap<>();
|
||||
context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f));
|
||||
hasNext = true;
|
||||
searchHasShardFailure = false;
|
||||
this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create);
|
||||
|
@ -188,26 +211,78 @@ public class DataFrameDataExtractor {
|
|||
return rows;
|
||||
}
|
||||
|
||||
private String extractNonProcessedValues(SearchHit hit, String organicFeature) {
|
||||
ExtractedField field = extractedFieldsByName.get(organicFeature);
|
||||
Object[] values = field.value(hit);
|
||||
if (values.length == 1 && isValidValue(values[0])) {
|
||||
return Objects.toString(values[0]);
|
||||
}
|
||||
if (values.length == 0 && context.supportsRowsWithMissingValues) {
|
||||
// if values is empty then it means it's a missing value
|
||||
return NULL_VALUE;
|
||||
}
|
||||
// we are here if we have a missing value but the analysis does not support those
|
||||
// or the value type is not supported (e.g. arrays, etc.)
|
||||
return null;
|
||||
}
|
||||
|
||||
private String[] extractProcessedValue(ProcessedField processedField, SearchHit hit) {
|
||||
Object[] values = processedField.value(hit, extractedFieldsByName::get);
|
||||
if (values.length == 0 && context.supportsRowsWithMissingValues == false) {
|
||||
return null;
|
||||
}
|
||||
final String[] extractedValue = new String[processedField.getOutputFieldNames().size()];
|
||||
for (int i = 0; i < processedField.getOutputFieldNames().size(); i++) {
|
||||
extractedValue[i] = NULL_VALUE;
|
||||
}
|
||||
// if values is empty then it means it's a missing value
|
||||
if (values.length == 0) {
|
||||
return extractedValue;
|
||||
}
|
||||
|
||||
if (values.length != processedField.getOutputFieldNames().size()) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"field_processor [{}] output size expected to be [{}], instead it was [{}]",
|
||||
processedField.getProcessorName(),
|
||||
processedField.getOutputFieldNames().size(),
|
||||
values.length);
|
||||
}
|
||||
|
||||
for (int i = 0; i < processedField.getOutputFieldNames().size(); ++i) {
|
||||
Object value = values[i];
|
||||
if (value == null && context.supportsRowsWithMissingValues) {
|
||||
continue;
|
||||
}
|
||||
if (isValidValue(value) == false) {
|
||||
// we are here if we have a missing value but the analysis does not support those
|
||||
// or the value type is not supported (e.g. arrays, etc.)
|
||||
return null;
|
||||
}
|
||||
extractedValue[i] = Objects.toString(value);
|
||||
}
|
||||
return extractedValue;
|
||||
}
|
||||
|
||||
private Row createRow(SearchHit hit) {
|
||||
String[] extractedValues = new String[context.extractedFields.getAllFields().size()];
|
||||
for (int i = 0; i < extractedValues.length; ++i) {
|
||||
ExtractedField field = context.extractedFields.getAllFields().get(i);
|
||||
Object[] values = field.value(hit);
|
||||
if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) {
|
||||
extractedValues[i] = Objects.toString(values[0]);
|
||||
} else {
|
||||
if (values.length == 0 && context.supportsRowsWithMissingValues) {
|
||||
// if values is empty then it means it's a missing value
|
||||
extractedValues[i] = NULL_VALUE;
|
||||
} else {
|
||||
// we are here if we have a missing value but the analysis does not support those
|
||||
// or the value type is not supported (e.g. arrays, etc.)
|
||||
extractedValues = null;
|
||||
break;
|
||||
}
|
||||
String[] extractedValues = new String[organicFeatures.length + processedFeatures.length];
|
||||
int i = 0;
|
||||
for (String organicFeature : organicFeatures) {
|
||||
String extractedValue = extractNonProcessedValues(hit, organicFeature);
|
||||
if (extractedValue == null) {
|
||||
return new Row(null, hit, true);
|
||||
}
|
||||
extractedValues[i++] = extractedValue;
|
||||
}
|
||||
for (ProcessedField processedField : context.extractedFields.getProcessedFields()) {
|
||||
String[] processedValues = extractProcessedValue(processedField, hit);
|
||||
if (processedValues == null) {
|
||||
return new Row(null, hit, true);
|
||||
}
|
||||
for (String processedValue : processedValues) {
|
||||
extractedValues[i++] = processedValue;
|
||||
}
|
||||
}
|
||||
boolean isTraining = extractedValues == null ? false : trainTestSplitter.get().isTraining(extractedValues);
|
||||
boolean isTraining = trainTestSplitter.get().isTraining(extractedValues);
|
||||
return new Row(extractedValues, hit, isTraining);
|
||||
}
|
||||
|
||||
|
@ -241,7 +316,7 @@ public class DataFrameDataExtractor {
|
|||
}
|
||||
|
||||
public List<String> getFieldNames() {
|
||||
return context.extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList());
|
||||
return Stream.concat(Arrays.stream(organicFeatures), Arrays.stream(processedFeatures)).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public ExtractedFields getExtractedFields() {
|
||||
|
@ -253,12 +328,12 @@ public class DataFrameDataExtractor {
|
|||
SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
|
||||
long rows = searchResponse.getHits().getTotalHits().value;
|
||||
LOGGER.debug("[{}] Data summary rows [{}]", context.jobId, rows);
|
||||
return new DataSummary(rows, context.extractedFields.getAllFields().size());
|
||||
return new DataSummary(rows, organicFeatures.length + processedFeatures.length);
|
||||
}
|
||||
|
||||
public void collectDataSummaryAsync(ActionListener<DataSummary> dataSummaryActionListener) {
|
||||
SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
|
||||
final int numberOfFields = context.extractedFields.getAllFields().size();
|
||||
final int numberOfFields = organicFeatures.length + processedFeatures.length;
|
||||
|
||||
ClientHelper.executeWithHeadersAsync(context.headers,
|
||||
ClientHelper.ML_ORIGIN,
|
||||
|
@ -298,7 +373,11 @@ public class DataFrameDataExtractor {
|
|||
}
|
||||
|
||||
public Set<String> getCategoricalFields(DataFrameAnalysis analysis) {
|
||||
return ExtractedFieldsDetector.getCategoricalFields(context.extractedFields, analysis);
|
||||
return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis);
|
||||
}
|
||||
|
||||
private static boolean isValidValue(Object value) {
|
||||
return value instanceof Number || value instanceof String;
|
||||
}
|
||||
|
||||
public static class DataSummary {
|
||||
|
|
|
@ -13,27 +13,33 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
|
|||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.regex.Regex;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.index.IndexSettings;
|
||||
import org.elasticsearch.index.mapper.BooleanFieldMapper;
|
||||
import org.elasticsearch.index.mapper.ObjectMapper;
|
||||
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NameResolver;
|
||||
import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
import org.elasticsearch.xpack.ml.extractor.ProcessedField;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
|
@ -60,7 +66,9 @@ public class ExtractedFieldsDetector {
|
|||
private final FieldCapabilitiesResponse fieldCapabilitiesResponse;
|
||||
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
|
||||
|
||||
ExtractedFieldsDetector(DataFrameAnalyticsConfig config, int docValueFieldsLimit, FieldCapabilitiesResponse fieldCapabilitiesResponse,
|
||||
ExtractedFieldsDetector(DataFrameAnalyticsConfig config,
|
||||
int docValueFieldsLimit,
|
||||
FieldCapabilitiesResponse fieldCapabilitiesResponse,
|
||||
Map<String, Long> cardinalitiesForFieldsWithConstraints) {
|
||||
this.config = Objects.requireNonNull(config);
|
||||
this.docValueFieldsLimit = docValueFieldsLimit;
|
||||
|
@ -69,23 +77,39 @@ public class ExtractedFieldsDetector {
|
|||
}
|
||||
|
||||
public Tuple<ExtractedFields, List<FieldSelection>> detect() {
|
||||
List<ProcessedField> processedFields = extractFeatureProcessors()
|
||||
.stream()
|
||||
.map(ProcessedField::new)
|
||||
.collect(Collectors.toList());
|
||||
TreeSet<FieldSelection> fieldSelection = new TreeSet<>(Comparator.comparing(FieldSelection::getName));
|
||||
Set<String> fields = getIncludedFields(fieldSelection);
|
||||
Set<String> fields = getIncludedFields(fieldSelection,
|
||||
processedFields.stream()
|
||||
.map(ProcessedField::getInputFieldNames)
|
||||
.flatMap(List::stream)
|
||||
.collect(Collectors.toSet()));
|
||||
checkFieldsHaveCompatibleTypes(fields);
|
||||
checkRequiredFields(fields);
|
||||
checkFieldsWithCardinalityLimit();
|
||||
ExtractedFields extractedFields = detectExtractedFields(fields, fieldSelection);
|
||||
ExtractedFields extractedFields = detectExtractedFields(fields, fieldSelection, processedFields);
|
||||
addIncludedFields(extractedFields, fieldSelection);
|
||||
|
||||
checkOutputFeatureUniqueness(processedFields, fields);
|
||||
|
||||
return Tuple.tuple(extractedFields, Collections.unmodifiableList(new ArrayList<>(fieldSelection)));
|
||||
}
|
||||
|
||||
private Set<String> getIncludedFields(Set<FieldSelection> fieldSelection) {
|
||||
private Set<String> getIncludedFields(Set<FieldSelection> fieldSelection, Set<String> requiredFieldsForProcessors) {
|
||||
Set<String> fields = new TreeSet<>(fieldCapabilitiesResponse.get().keySet());
|
||||
validateFieldsRequireForProcessors(requiredFieldsForProcessors);
|
||||
fields.removeAll(IGNORE_FIELDS);
|
||||
removeFieldsUnderResultsField(fields);
|
||||
removeObjects(fields);
|
||||
applySourceFiltering(fields);
|
||||
if (fields.containsAll(requiredFieldsForProcessors) == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"fields {} required by field_processors are not included in source filtering.",
|
||||
Sets.difference(requiredFieldsForProcessors, fields));
|
||||
}
|
||||
FetchSourceContext analyzedFields = config.getAnalyzedFields();
|
||||
|
||||
// If the user has not explicitly included fields we'll include all compatible fields
|
||||
|
@ -93,20 +117,63 @@ public class ExtractedFieldsDetector {
|
|||
removeFieldsWithIncompatibleTypes(fields, fieldSelection);
|
||||
}
|
||||
includeAndExcludeFields(fields, fieldSelection);
|
||||
if (fields.containsAll(requiredFieldsForProcessors) == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"fields {} required by field_processors are not included in the analyzed_fields.",
|
||||
Sets.difference(requiredFieldsForProcessors, fields));
|
||||
}
|
||||
|
||||
return fields;
|
||||
}
|
||||
|
||||
private void validateFieldsRequireForProcessors(Set<String> processorFields) {
|
||||
Set<String> fieldsForProcessor = new HashSet<>(processorFields);
|
||||
removeFieldsUnderResultsField(fieldsForProcessor);
|
||||
if (fieldsForProcessor.size() < processorFields.size()) {
|
||||
throw ExceptionsHelper.badRequestException("fields contained in results field [{}] cannot be used in a feature_processor",
|
||||
config.getDest().getResultsField());
|
||||
}
|
||||
removeObjects(fieldsForProcessor);
|
||||
if (fieldsForProcessor.size() < processorFields.size()) {
|
||||
throw ExceptionsHelper.badRequestException("fields for feature_processors must not be objects");
|
||||
}
|
||||
fieldsForProcessor.removeAll(IGNORE_FIELDS);
|
||||
if (fieldsForProcessor.size() < processorFields.size()) {
|
||||
throw ExceptionsHelper.badRequestException("the following fields cannot be used in feature_processors {}", IGNORE_FIELDS);
|
||||
}
|
||||
List<String> fieldsMissingInMapping = processorFields.stream()
|
||||
.filter(f -> fieldCapabilitiesResponse.get().containsKey(f) == false)
|
||||
.collect(Collectors.toList());
|
||||
if (fieldsMissingInMapping.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"the fields {} were not found in the field capabilities of the source indices [{}]. "
|
||||
+ "Fields must exist and be mapped to be used in feature_processors.",
|
||||
fieldsMissingInMapping,
|
||||
Strings.arrayToCommaDelimitedString(config.getSource().getIndex()));
|
||||
}
|
||||
List<String> processedRequiredFields = config.getAnalysis()
|
||||
.getRequiredFields()
|
||||
.stream()
|
||||
.map(RequiredField::getName)
|
||||
.filter(processorFields::contains)
|
||||
.collect(Collectors.toList());
|
||||
if (processedRequiredFields.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"required analysis fields {} cannot be used in a feature_processor",
|
||||
processedRequiredFields);
|
||||
}
|
||||
}
|
||||
|
||||
private void removeFieldsUnderResultsField(Set<String> fields) {
|
||||
String resultsField = config.getDest().getResultsField();
|
||||
final String resultsFieldPrefix = config.getDest().getResultsField() + ".";
|
||||
Iterator<String> fieldsIterator = fields.iterator();
|
||||
while (fieldsIterator.hasNext()) {
|
||||
String field = fieldsIterator.next();
|
||||
if (field.startsWith(resultsField + ".")) {
|
||||
if (field.startsWith(resultsFieldPrefix)) {
|
||||
fieldsIterator.remove();
|
||||
}
|
||||
}
|
||||
fields.removeIf(field -> field.startsWith(resultsField + "."));
|
||||
fields.removeIf(field -> field.startsWith(resultsFieldPrefix));
|
||||
}
|
||||
|
||||
private void removeObjects(Set<String> fields) {
|
||||
|
@ -287,9 +354,23 @@ public class ExtractedFieldsDetector {
|
|||
}
|
||||
}
|
||||
|
||||
private ExtractedFields detectExtractedFields(Set<String> fields, Set<FieldSelection> fieldSelection) {
|
||||
ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse,
|
||||
cardinalitiesForFieldsWithConstraints);
|
||||
private List<PreProcessor> extractFeatureProcessors() {
|
||||
if (config.getAnalysis() instanceof Classification) {
|
||||
return ((Classification)config.getAnalysis()).getFeatureProcessors();
|
||||
} else if (config.getAnalysis() instanceof Regression) {
|
||||
return ((Regression)config.getAnalysis()).getFeatureProcessors();
|
||||
}
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
private ExtractedFields detectExtractedFields(Set<String> fields,
|
||||
Set<FieldSelection> fieldSelection,
|
||||
List<ProcessedField> processedFields) {
|
||||
ExtractedFields extractedFields = ExtractedFields.build(fields,
|
||||
Collections.emptySet(),
|
||||
fieldCapabilitiesResponse,
|
||||
cardinalitiesForFieldsWithConstraints,
|
||||
processedFields);
|
||||
boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
|
||||
extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection);
|
||||
if (preferSource) {
|
||||
|
@ -304,10 +385,15 @@ public class ExtractedFieldsDetector {
|
|||
return extractedFields;
|
||||
}
|
||||
|
||||
private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, boolean preferSource,
|
||||
private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields,
|
||||
boolean preferSource,
|
||||
Set<FieldSelection> fieldSelection) {
|
||||
Set<String> requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName)
|
||||
Set<String> requiredFields = config.getAnalysis()
|
||||
.getRequiredFields()
|
||||
.stream()
|
||||
.map(RequiredField::getName)
|
||||
.collect(Collectors.toSet());
|
||||
Set<String> processorInputFields = extractedFields.getProcessedFieldInputs();
|
||||
Map<String, ExtractedField> nameOrParentToField = new LinkedHashMap<>();
|
||||
for (ExtractedField currentField : extractedFields.getAllFields()) {
|
||||
String nameOrParent = currentField.isMultiField() ? currentField.getParentField() : currentField.getName();
|
||||
|
@ -315,15 +401,37 @@ public class ExtractedFieldsDetector {
|
|||
if (existingField != null) {
|
||||
ExtractedField parent = currentField.isMultiField() ? existingField : currentField;
|
||||
ExtractedField multiField = currentField.isMultiField() ? currentField : existingField;
|
||||
// If required fields contains parent or multifield and the processor input fields reference the other, that is an error
|
||||
// we should not allow processing of data that is required.
|
||||
if ((requiredFields.contains(parent.getName()) && processorInputFields.contains(multiField.getName()))
|
||||
|| (requiredFields.contains(multiField.getName()) && processorInputFields.contains(parent.getName()))) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"feature_processors cannot be applied to required fields for analysis; multi-field [{}] parent [{}]",
|
||||
multiField.getName(),
|
||||
parent.getName());
|
||||
}
|
||||
// If processor input fields have BOTH, we need to keep both.
|
||||
if (processorInputFields.contains(parent.getName()) && processorInputFields.contains(multiField.getName())) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"feature_processors refer to both multi-field [{}] and parent [{}]. Please only refer to one or the other",
|
||||
multiField.getName(),
|
||||
parent.getName());
|
||||
}
|
||||
nameOrParentToField.put(nameOrParent,
|
||||
chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection));
|
||||
chooseMultiFieldOrParent(preferSource, requiredFields, processorInputFields, parent, multiField, fieldSelection));
|
||||
}
|
||||
}
|
||||
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints);
|
||||
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()),
|
||||
extractedFields.getProcessedFields(),
|
||||
cardinalitiesForFieldsWithConstraints);
|
||||
}
|
||||
|
||||
private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields, ExtractedField parent,
|
||||
ExtractedField multiField, Set<FieldSelection> fieldSelection) {
|
||||
private ExtractedField chooseMultiFieldOrParent(boolean preferSource,
|
||||
Set<String> requiredFields,
|
||||
Set<String> processorInputFields,
|
||||
ExtractedField parent,
|
||||
ExtractedField multiField,
|
||||
Set<FieldSelection> fieldSelection) {
|
||||
// Check requirements first
|
||||
if (requiredFields.contains(parent.getName())) {
|
||||
addExcludedField(multiField.getName(), "[" + parent.getName() + "] is required instead", fieldSelection);
|
||||
|
@ -333,6 +441,19 @@ public class ExtractedFieldsDetector {
|
|||
addExcludedField(parent.getName(), "[" + multiField.getName() + "] is required instead", fieldSelection);
|
||||
return multiField;
|
||||
}
|
||||
// Choose the one required by our processors
|
||||
if (processorInputFields.contains(parent.getName())) {
|
||||
addExcludedField(multiField.getName(),
|
||||
"[" + parent.getName() + "] is referenced by feature_processors instead",
|
||||
fieldSelection);
|
||||
return parent;
|
||||
}
|
||||
if (processorInputFields.contains(multiField.getName())) {
|
||||
addExcludedField(parent.getName(),
|
||||
"[" + multiField.getName() + "] is referenced by feature_processors instead",
|
||||
fieldSelection);
|
||||
return multiField;
|
||||
}
|
||||
|
||||
// If both are multi-fields it means there are several. In this case parent is the previous multi-field
|
||||
// we selected. We'll just keep that.
|
||||
|
@ -370,7 +491,9 @@ public class ExtractedFieldsDetector {
|
|||
for (ExtractedField field : extractedFields.getAllFields()) {
|
||||
adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
|
||||
}
|
||||
return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
|
||||
return new ExtractedFields(adjusted,
|
||||
extractedFields.getProcessedFields(),
|
||||
cardinalitiesForFieldsWithConstraints);
|
||||
}
|
||||
|
||||
private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) {
|
||||
|
@ -387,13 +510,15 @@ public class ExtractedFieldsDetector {
|
|||
adjusted.add(field);
|
||||
}
|
||||
}
|
||||
return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
|
||||
return new ExtractedFields(adjusted,
|
||||
extractedFields.getProcessedFields(),
|
||||
cardinalitiesForFieldsWithConstraints);
|
||||
}
|
||||
|
||||
private void addIncludedFields(ExtractedFields extractedFields, Set<FieldSelection> fieldSelection) {
|
||||
Set<String> requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName)
|
||||
.collect(Collectors.toSet());
|
||||
Set<String> categoricalFields = getCategoricalFields(extractedFields, config.getAnalysis());
|
||||
Set<String> categoricalFields = getCategoricalInputFields(extractedFields, config.getAnalysis());
|
||||
for (ExtractedField includedField : extractedFields.getAllFields()) {
|
||||
FieldSelection.FeatureType featureType = categoricalFields.contains(includedField.getName()) ?
|
||||
FieldSelection.FeatureType.CATEGORICAL : FieldSelection.FeatureType.NUMERICAL;
|
||||
|
@ -402,7 +527,38 @@ public class ExtractedFieldsDetector {
|
|||
}
|
||||
}
|
||||
|
||||
static Set<String> getCategoricalFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) {
|
||||
static void checkOutputFeatureUniqueness(List<ProcessedField> processedFields, Set<String> selectedFields) {
|
||||
Set<String> processInputs = processedFields.stream()
|
||||
.map(ProcessedField::getInputFieldNames)
|
||||
.flatMap(List::stream)
|
||||
.collect(Collectors.toSet());
|
||||
// All analysis fields that we include that are NOT processed
|
||||
// This indicates that they are sent as is
|
||||
Set<String> organicFields = Sets.difference(selectedFields, processInputs);
|
||||
|
||||
Set<String> processedFeatures = new HashSet<>();
|
||||
Set<String> duplicatedFields = new HashSet<>();
|
||||
for (ProcessedField processedField : processedFields) {
|
||||
for (String output : processedField.getOutputFieldNames()) {
|
||||
if (processedFeatures.add(output) == false) {
|
||||
duplicatedFields.add(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (duplicatedFields.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"feature_processors must define unique output field names; duplicate fields {}",
|
||||
duplicatedFields);
|
||||
}
|
||||
Set<String> duplicateOrganicAndProcessed = Sets.intersection(organicFields, processedFeatures);
|
||||
if (duplicateOrganicAndProcessed.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"feature_processors output fields must not include non-processed analysis fields; duplicate fields {}",
|
||||
duplicateOrganicAndProcessed);
|
||||
}
|
||||
}
|
||||
|
||||
static Set<String> getCategoricalInputFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) {
|
||||
return extractedFields.getAllFields().stream()
|
||||
.filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName())
|
||||
.containsAll(extractedField.getTypes()))
|
||||
|
@ -410,6 +566,25 @@ public class ExtractedFieldsDetector {
|
|||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
static Set<String> getCategoricalOutputFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) {
|
||||
Set<String> processInputFields = extractedFields.getProcessedFieldInputs();
|
||||
Set<String> categoricalFields = extractedFields.getAllFields().stream()
|
||||
.filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName())
|
||||
.containsAll(extractedField.getTypes()))
|
||||
.map(ExtractedField::getName)
|
||||
.filter(name -> processInputFields.contains(name) == false)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
extractedFields.getProcessedFields().forEach(processedField ->
|
||||
processedField.getOutputFieldNames().forEach(outputField -> {
|
||||
if (analysis.getAllowedCategoricalTypes(outputField).containsAll(processedField.getOutputFieldType(outputField))) {
|
||||
categoricalFields.add(outputField);
|
||||
}
|
||||
})
|
||||
);
|
||||
return Collections.unmodifiableSet(categoricalFields);
|
||||
}
|
||||
|
||||
private static boolean isBoolean(Set<String> types) {
|
||||
return types.size() == 1 && types.contains(BooleanFieldMapper.CONTENT_TYPE);
|
||||
}
|
||||
|
|
|
@ -178,7 +178,7 @@ public class AnalyticsProcessManager {
|
|||
AnalyticsProcess<AnalyticsResult> process = processContext.process.get();
|
||||
AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
|
||||
try {
|
||||
writeHeaderRecord(dataExtractor, process);
|
||||
writeHeaderRecord(dataExtractor, process, task);
|
||||
writeDataRows(dataExtractor, process, task);
|
||||
process.writeEndOfDataMessage();
|
||||
process.flushStream();
|
||||
|
@ -268,8 +268,11 @@ public class AnalyticsProcessManager {
|
|||
}
|
||||
}
|
||||
|
||||
private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process) throws IOException {
|
||||
private void writeHeaderRecord(DataFrameDataExtractor dataExtractor,
|
||||
AnalyticsProcess<AnalyticsResult> process,
|
||||
DataFrameAnalyticsTask task) throws IOException {
|
||||
List<String> fieldNames = dataExtractor.getFieldNames();
|
||||
LOGGER.debug(() -> new ParameterizedMessage("[{}] header row fields {}", task.getParams().getId(), fieldNames));
|
||||
|
||||
// We add 2 extra fields, both named dot:
|
||||
// - the document hash
|
||||
|
|
|
@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.dataframe.process;
|
|||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.LatchedActionListener;
|
||||
|
@ -22,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.security.user.XPackUser;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
|
||||
|
@ -34,6 +36,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
|||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
@ -191,8 +194,21 @@ public class ChunkedTrainedModelPersister {
|
|||
return latch;
|
||||
}
|
||||
|
||||
private long customProcessorSize() {
|
||||
List<PreProcessor> preProcessors = new ArrayList<>();
|
||||
if (analytics.getAnalysis() instanceof Classification) {
|
||||
preProcessors = ((Classification) analytics.getAnalysis()).getFeatureProcessors();
|
||||
} else if (analytics.getAnalysis() instanceof Regression) {
|
||||
preProcessors = ((Regression) analytics.getAnalysis()).getFeatureProcessors();
|
||||
}
|
||||
return preProcessors.stream().mapToLong(PreProcessor::ramBytesUsed).sum()
|
||||
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * preProcessors.size();
|
||||
}
|
||||
|
||||
private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) {
|
||||
Instant createTime = Instant.now();
|
||||
// The native process does not provide estimates for the custom feature_processor objects
|
||||
long customProcessorSize = customProcessorSize();
|
||||
String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
|
||||
currentModelId.set(modelId);
|
||||
List<ExtractedField> fieldNames = extractedFields.getAllFields();
|
||||
|
@ -214,7 +230,7 @@ public class ChunkedTrainedModelPersister {
|
|||
.setDescription(analytics.getDescription())
|
||||
.setMetadata(Collections.singletonMap("analytics_config",
|
||||
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
|
||||
.setEstimatedHeapMemory(modelSize.ramBytesUsed())
|
||||
.setEstimatedHeapMemory(modelSize.ramBytesUsed() + customProcessorSize)
|
||||
.setEstimatedOperations(modelSize.numOperations())
|
||||
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
|
||||
.setLicenseLevel(License.OperationMode.PLATINUM.description())
|
||||
|
|
|
@ -12,7 +12,7 @@ import org.elasticsearch.index.mapper.BooleanFieldMapper;
|
|||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
@ -21,27 +21,39 @@ import java.util.Set;
|
|||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* The fields the datafeed has to extract
|
||||
* The fields the data[feed|frame] has to extract
|
||||
*/
|
||||
public class ExtractedFields {
|
||||
|
||||
private final List<ExtractedField> allFields;
|
||||
private final List<ExtractedField> docValueFields;
|
||||
private final List<ProcessedField> processedFields;
|
||||
private final String[] sourceFields;
|
||||
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
|
||||
|
||||
public ExtractedFields(List<ExtractedField> allFields, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
|
||||
this.allFields = Collections.unmodifiableList(allFields);
|
||||
public ExtractedFields(List<ExtractedField> allFields,
|
||||
List<ProcessedField> processedFields,
|
||||
Map<String, Long> cardinalitiesForFieldsWithConstraints) {
|
||||
this.allFields = new ArrayList<>(allFields);
|
||||
this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields);
|
||||
this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField)
|
||||
.toArray(String[]::new);
|
||||
this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints);
|
||||
this.processedFields = processedFields == null ? Collections.emptyList() : processedFields;
|
||||
}
|
||||
|
||||
public List<ProcessedField> getProcessedFields() {
|
||||
return processedFields;
|
||||
}
|
||||
|
||||
public List<ExtractedField> getAllFields() {
|
||||
return allFields;
|
||||
}
|
||||
|
||||
public Set<String> getProcessedFieldInputs() {
|
||||
return processedFields.stream().map(ProcessedField::getInputFieldNames).flatMap(List::stream).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
public String[] getSourceFields() {
|
||||
return sourceFields;
|
||||
}
|
||||
|
@ -58,11 +70,15 @@ public class ExtractedFields {
|
|||
return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static ExtractedFields build(Collection<String> allFields, Set<String> scriptFields,
|
||||
public static ExtractedFields build(Set<String> allFields,
|
||||
Set<String> scriptFields,
|
||||
FieldCapabilitiesResponse fieldsCapabilities,
|
||||
Map<String, Long> cardinalitiesForFieldsWithConstraints) {
|
||||
Map<String, Long> cardinalitiesForFieldsWithConstraints,
|
||||
List<ProcessedField> processedFields) {
|
||||
ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities);
|
||||
return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()),
|
||||
return new ExtractedFields(
|
||||
allFields.stream().map(extractionMethodDetector::detect).collect(Collectors.toList()),
|
||||
processedFields,
|
||||
cardinalitiesForFieldsWithConstraints);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.extractor;
|
||||
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
|
||||
public class ProcessedField {
|
||||
private final PreProcessor preProcessor;
|
||||
|
||||
public ProcessedField(PreProcessor processor) {
|
||||
this.preProcessor = Objects.requireNonNull(processor);
|
||||
}
|
||||
|
||||
public List<String> getInputFieldNames() {
|
||||
return preProcessor.inputFields();
|
||||
}
|
||||
|
||||
public List<String> getOutputFieldNames() {
|
||||
return preProcessor.outputFields();
|
||||
}
|
||||
|
||||
public Set<String> getOutputFieldType(String outputField) {
|
||||
return Collections.singleton(preProcessor.getOutputFieldType(outputField));
|
||||
}
|
||||
|
||||
public Object[] value(SearchHit hit, Function<String, ExtractedField> fieldExtractor) {
|
||||
Map<String, Object> inputs = new HashMap<>(preProcessor.inputFields().size(), 1.0f);
|
||||
for (String field : preProcessor.inputFields()) {
|
||||
ExtractedField extractedField = fieldExtractor.apply(field);
|
||||
if (extractedField == null) {
|
||||
return new Object[0];
|
||||
}
|
||||
Object[] values = extractedField.value(hit);
|
||||
if (values == null || values.length == 0) {
|
||||
continue;
|
||||
}
|
||||
final Object value = values[0];
|
||||
if (values.length == 1 && (value instanceof String || value instanceof Number)) {
|
||||
inputs.put(field, value);
|
||||
}
|
||||
}
|
||||
preProcessor.process(inputs);
|
||||
return preProcessor.outputFields().stream().map(inputs::get).toArray();
|
||||
}
|
||||
|
||||
public String getProcessorName() {
|
||||
return preProcessor.getName();
|
||||
}
|
||||
|
||||
}
|
|
@ -128,4 +128,11 @@ public abstract class MlSingleNodeTestCase extends ESSingleNodeTestCase {
|
|||
return responseHolder.get();
|
||||
}
|
||||
|
||||
public static void assertNoException(AtomicReference<Exception> error) throws Exception {
|
||||
if (error.get() == null) {
|
||||
return;
|
||||
}
|
||||
throw error.get();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -15,8 +15,10 @@ import org.elasticsearch.action.search.SearchResponse;
|
|||
import org.elasticsearch.action.search.ShardSearchFailure;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
|
@ -27,10 +29,13 @@ import org.elasticsearch.threadpool.ThreadPool;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
|
||||
import org.elasticsearch.xpack.ml.extractor.DocValueField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
import org.elasticsearch.xpack.ml.extractor.ProcessedField;
|
||||
import org.elasticsearch.xpack.ml.extractor.SourceField;
|
||||
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
|
||||
import org.junit.Before;
|
||||
|
@ -45,8 +50,10 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Queue;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
|
@ -83,7 +90,9 @@ public class DataFrameDataExtractorTests extends ESTestCase {
|
|||
query = QueryBuilders.matchAllQuery();
|
||||
extractedFields = new ExtractedFields(Arrays.asList(
|
||||
new DocValueField("field_1", Collections.singleton("keyword")),
|
||||
new DocValueField("field_2", Collections.singleton("keyword"))), Collections.emptyMap());
|
||||
new DocValueField("field_2", Collections.singleton("keyword"))),
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap());
|
||||
scrollSize = 1000;
|
||||
headers = Collections.emptyMap();
|
||||
|
||||
|
@ -304,7 +313,9 @@ public class DataFrameDataExtractorTests extends ESTestCase {
|
|||
// Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
|
||||
extractedFields = new ExtractedFields(Arrays.asList(
|
||||
(ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")),
|
||||
(ExtractedField) new SourceField("field_2", Collections.singleton("text"))), Collections.emptyMap());
|
||||
(ExtractedField) new SourceField("field_2", Collections.singleton("text"))),
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap());
|
||||
|
||||
TestExtractor dataExtractor = createExtractor(false, false);
|
||||
|
||||
|
@ -445,7 +456,9 @@ public class DataFrameDataExtractorTests extends ESTestCase {
|
|||
(ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
|
||||
(ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
|
||||
(ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
|
||||
(ExtractedField) new SourceField("field_text", Collections.singleton("text"))), Collections.emptyMap());
|
||||
(ExtractedField) new SourceField("field_text", Collections.singleton("text"))),
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap());
|
||||
TestExtractor dataExtractor = createExtractor(true, true);
|
||||
|
||||
assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty());
|
||||
|
@ -465,12 +478,100 @@ public class DataFrameDataExtractorTests extends ESTestCase {
|
|||
containsInAnyOrder("field_keyword", "field_text", "field_boolean"));
|
||||
}
|
||||
|
||||
public void testGetFieldNames_GivenProcessesFeatures() {
|
||||
// Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
|
||||
extractedFields = new ExtractedFields(Arrays.asList(
|
||||
(ExtractedField) new DocValueField("field_boolean", Collections.singleton("boolean")),
|
||||
(ExtractedField) new DocValueField("field_float", Collections.singleton("float")),
|
||||
(ExtractedField) new DocValueField("field_double", Collections.singleton("double")),
|
||||
(ExtractedField) new DocValueField("field_byte", Collections.singleton("byte")),
|
||||
(ExtractedField) new DocValueField("field_short", Collections.singleton("short")),
|
||||
(ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
|
||||
(ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
|
||||
(ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
|
||||
(ExtractedField) new SourceField("field_text", Collections.singleton("text"))),
|
||||
Arrays.asList(
|
||||
new ProcessedField(new CategoricalPreProcessor("field_long", "animal")),
|
||||
buildProcessedField("field_short", "field_1", "field_2")
|
||||
),
|
||||
Collections.emptyMap());
|
||||
TestExtractor dataExtractor = createExtractor(true, true);
|
||||
|
||||
assertThat(dataExtractor.getCategoricalFields(new Regression("field_double")),
|
||||
containsInAnyOrder("field_keyword", "field_text", "animal"));
|
||||
|
||||
List<String> fieldNames = dataExtractor.getFieldNames();
|
||||
assertThat(fieldNames, containsInAnyOrder(
|
||||
"animal",
|
||||
"field_1",
|
||||
"field_2",
|
||||
"field_boolean",
|
||||
"field_float",
|
||||
"field_double",
|
||||
"field_byte",
|
||||
"field_integer",
|
||||
"field_keyword",
|
||||
"field_text"));
|
||||
assertThat(dataExtractor.getFieldNames(), contains(fieldNames.toArray(new String[0])));
|
||||
}
|
||||
|
||||
public void testExtractionWithProcessedFeatures() throws IOException {
|
||||
extractedFields = new ExtractedFields(Arrays.asList(
|
||||
new DocValueField("field_1", Collections.singleton("keyword")),
|
||||
new DocValueField("field_2", Collections.singleton("keyword"))),
|
||||
Arrays.asList(
|
||||
new ProcessedField(new CategoricalPreProcessor("field_1", "animal")),
|
||||
new ProcessedField(new OneHotEncoding("field_1",
|
||||
Arrays.asList("11", "12")
|
||||
.stream()
|
||||
.collect(Collectors.toMap(Function.identity(), s -> s.equals("11") ? "field_11" : "field_12")),
|
||||
true))
|
||||
),
|
||||
Collections.emptyMap());
|
||||
|
||||
TestExtractor dataExtractor = createExtractor(true, true);
|
||||
|
||||
// First and only batch
|
||||
SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3));
|
||||
dataExtractor.setNextResponse(response1);
|
||||
|
||||
// Empty
|
||||
SearchResponse lastAndEmptyResponse = createEmptySearchResponse();
|
||||
dataExtractor.setNextResponse(lastAndEmptyResponse);
|
||||
|
||||
assertThat(dataExtractor.hasNext(), is(true));
|
||||
|
||||
// First batch
|
||||
Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
|
||||
assertThat(rows.isPresent(), is(true));
|
||||
assertThat(rows.get().size(), equalTo(3));
|
||||
|
||||
assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"21", "dog", "1", "0"}));
|
||||
assertThat(rows.get().get(1).getValues(),
|
||||
equalTo(new String[] {"22", "dog", DataFrameDataExtractor.NULL_VALUE, DataFrameDataExtractor.NULL_VALUE}));
|
||||
assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"23", "dog", "0", "0"}));
|
||||
|
||||
assertThat(rows.get().get(0).shouldSkip(), is(false));
|
||||
assertThat(rows.get().get(1).shouldSkip(), is(false));
|
||||
assertThat(rows.get().get(2).shouldSkip(), is(false));
|
||||
}
|
||||
|
||||
private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) {
|
||||
DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize,
|
||||
headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory);
|
||||
return new TestExtractor(client, context);
|
||||
}
|
||||
|
||||
private static ProcessedField buildProcessedField(String inputField, String... outputFields) {
|
||||
return new ProcessedField(buildPreProcessor(inputField, outputFields));
|
||||
}
|
||||
|
||||
private static PreProcessor buildPreProcessor(String inputField, String... outputFields) {
|
||||
return new OneHotEncoding(inputField,
|
||||
Arrays.stream(outputFields).collect(Collectors.toMap((s) -> randomAlphaOfLength(10), Function.identity())),
|
||||
true);
|
||||
}
|
||||
|
||||
private SearchResponse createSearchResponse(List<Number> field1Values, List<Number> field2Values) {
|
||||
assertThat(field1Values.size(), equalTo(field2Values.size()));
|
||||
SearchResponse searchResponse = mock(SearchResponse.class);
|
||||
|
@ -544,4 +645,70 @@ public class DataFrameDataExtractorTests extends ESTestCase {
|
|||
return searchResponse;
|
||||
}
|
||||
}
|
||||
|
||||
private static class CategoricalPreProcessor implements PreProcessor {
|
||||
|
||||
private final List<String> inputFields;
|
||||
private final List<String> outputFields;
|
||||
|
||||
CategoricalPreProcessor(String inputField, String outputField) {
|
||||
this.inputFields = Arrays.asList(inputField);
|
||||
this.outputFields = Arrays.asList(outputField);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> inputFields() {
|
||||
return inputFields;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> outputFields() {
|
||||
return outputFields;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Map<String, Object> fields) {
|
||||
fields.put(outputFields.get(0), "dog");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, String> reverseLookup() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCustom() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getOutputFieldType(String outputField) {
|
||||
return "text";
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,10 +15,13 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
|
||||
|
@ -30,11 +33,14 @@ import java.util.HashMap;
|
|||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.arrayContaining;
|
||||
import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
@ -929,12 +935,23 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
assertThat(e.getMessage(), equalTo("analyzed_fields must not include or exclude object fields: [object_field]"));
|
||||
}
|
||||
|
||||
private static FieldCapabilitiesResponse simpleFieldResponse() {
|
||||
return new MockFieldCapsResponseBuilder()
|
||||
.addAggregatableField("field_11", "float")
|
||||
.addNonAggregatableField("field_21", "float")
|
||||
.addAggregatableField("field_21.child", "float")
|
||||
.addNonAggregatableField("field_31", "float")
|
||||
.addAggregatableField("field_31.child", "float")
|
||||
.addNonAggregatableField("object_field", "object")
|
||||
.build();
|
||||
}
|
||||
|
||||
public void testDetect_GivenAnalyzedFieldExcludesObjectField() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
|
||||
.addAggregatableField("float_field", "float")
|
||||
.addNonAggregatableField("object_field", "object").build();
|
||||
|
||||
analyzedFields = new FetchSourceContext(true, null, new String[] { "object_field" });
|
||||
analyzedFields = new FetchSourceContext(true, null, new String[]{"object_field"});
|
||||
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap());
|
||||
|
@ -943,6 +960,177 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
assertThat(e.getMessage(), equalTo("analyzed_fields must not include or exclude object fields: [object_field]"));
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_ResultsField() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("ml.result", "foo"))),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("fields contained in results field [ml] cannot be used in a feature_processor"));
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_Objects() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("object_field", "foo"))),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("fields for feature_processors must not be objects"));
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_ReservedFields() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("_id", "foo"))),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("the following fields cannot be used in feature_processors"));
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_MissingFieldFromIndex() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("bar", "foo"))),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("the fields [bar] were not found in the field capabilities of the source indices"));
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_UsingRequiredField() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_31", "foo"))),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("required analysis fields [field_31] cannot be used in a feature_processor"));
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_BadSourceFiltering() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
sourceFiltering = new FetchSourceContext(true, null, new String[]{"field_1*"});
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_11", "foo"))),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("fields [field_11] required by field_processors are not included in source filtering."));
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_MissingAnalyzedField() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
analyzedFields = new FetchSourceContext(true, null, new String[]{"field_1*"});
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_11", "foo"))),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("fields [field_11] required by field_processors are not included in the analyzed_fields"));
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_RequiredMultiFields() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_31.child", "foo"))),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("feature_processors cannot be applied to required fields for analysis; "));
|
||||
|
||||
extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31.child", Arrays.asList(buildPreProcessor("field_31", "foo"))),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
|
||||
ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("feature_processors cannot be applied to required fields for analysis; "));
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_BothMultiFields() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31",
|
||||
Arrays.asList(
|
||||
buildPreProcessor("field_21", "foo"),
|
||||
buildPreProcessor("field_21.child", "bar")
|
||||
)),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("feature_processors refer to both multi-field "));
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_DuplicateOutputFields() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31",
|
||||
Arrays.asList(
|
||||
buildPreProcessor("field_11", "foo"),
|
||||
buildPreProcessor("field_21", "foo")
|
||||
)),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString("feature_processors must define unique output field names; duplicate fields [foo]"));
|
||||
}
|
||||
|
||||
public void testDetect_withFeatureProcessors() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
|
||||
.addAggregatableField("field_11", "float")
|
||||
.addAggregatableField("field_21", "float")
|
||||
.addNonAggregatableField("field_31", "float")
|
||||
.addAggregatableField("field_31.child", "float")
|
||||
.addNonAggregatableField("object_field", "object")
|
||||
.build();
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_11",
|
||||
Arrays.asList(buildPreProcessor("field_31", "foo", "bar"))),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
|
||||
ExtractedFields extracted = extractedFieldsDetector.detect().v1();
|
||||
|
||||
assertThat(extracted.getProcessedFieldInputs(), containsInAnyOrder("field_31"));
|
||||
assertThat(extracted.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toSet()),
|
||||
containsInAnyOrder("field_11", "field_21", "field_31"));
|
||||
assertThat(extracted.getSourceFields(), arrayContainingInAnyOrder("field_31"));
|
||||
assertThat(extracted.getDocValueFields().stream().map(ExtractedField::getName).collect(Collectors.toSet()),
|
||||
containsInAnyOrder("field_21", "field_11"));
|
||||
assertThat(extracted.getProcessedFields(), hasSize(1));
|
||||
}
|
||||
|
||||
private DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
|
||||
return new DataFrameAnalyticsConfig.Builder()
|
||||
.setId("foo")
|
||||
|
@ -954,13 +1142,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) {
|
||||
return new DataFrameAnalyticsConfig.Builder()
|
||||
.setId("foo")
|
||||
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, sourceFiltering))
|
||||
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
|
||||
.setAnalyzedFields(analyzedFields)
|
||||
.setAnalysis(new Regression(dependentVariable))
|
||||
.build();
|
||||
return buildRegressionConfig(dependentVariable, Collections.emptyList());
|
||||
}
|
||||
|
||||
private DataFrameAnalyticsConfig buildClassificationConfig(String dependentVariable) {
|
||||
|
@ -972,6 +1154,29 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
.build();
|
||||
}
|
||||
|
||||
private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, List<PreProcessor> featureprocessors) {
|
||||
return new DataFrameAnalyticsConfig.Builder()
|
||||
.setId("foo")
|
||||
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, sourceFiltering))
|
||||
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
|
||||
.setAnalyzedFields(analyzedFields)
|
||||
.setAnalysis(new Regression(dependentVariable,
|
||||
BoostedTreeParams.builder().build(),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
featureprocessors))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static PreProcessor buildPreProcessor(String inputField, String... outputFields) {
|
||||
return new OneHotEncoding(inputField,
|
||||
Arrays.stream(outputFields).collect(Collectors.toMap((s) -> randomAlphaOfLength(10), Function.identity())),
|
||||
true);
|
||||
}
|
||||
|
||||
/**
|
||||
* We assert each field individually to get useful error messages in case of failure
|
||||
*/
|
||||
|
@ -987,6 +1192,23 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testDetect_givenFeatureProcessorsFailures_DuplicateOutputFieldsWithUnProcessedField() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
buildRegressionConfig("field_31",
|
||||
Arrays.asList(
|
||||
buildPreProcessor("field_11", "field_21")
|
||||
)),
|
||||
100,
|
||||
fieldCapabilities,
|
||||
Collections.emptyMap());
|
||||
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
|
||||
assertThat(ex.getMessage(),
|
||||
containsString(
|
||||
"feature_processors output fields must not include non-processed analysis fields; duplicate fields [field_21]"));
|
||||
}
|
||||
|
||||
private static class MockFieldCapsResponseBuilder {
|
||||
|
||||
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();
|
||||
|
|
|
@ -80,6 +80,7 @@ public class InferenceRunnerTests extends ESTestCase {
|
|||
public void testInferTestDocs() {
|
||||
ExtractedFields extractedFields = new ExtractedFields(
|
||||
Collections.singletonList(new SourceField("key", Collections.singleton("integer"))),
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap());
|
||||
|
||||
Map<String, Object> doc1 = new HashMap<>();
|
||||
|
|
|
@ -63,7 +63,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase {
|
|||
public void testToXContent_GivenOutlierDetection() throws IOException {
|
||||
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
|
||||
new DocValueField("field_1", Collections.singleton("double")),
|
||||
new DocValueField("field_2", Collections.singleton("float"))), Collections.emptyMap());
|
||||
new DocValueField("field_2", Collections.singleton("float"))),
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap());
|
||||
DataFrameAnalysis analysis = new OutlierDetection.Builder().build();
|
||||
|
||||
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
|
||||
|
@ -82,7 +84,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase {
|
|||
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
|
||||
new DocValueField("field_1", Collections.singleton("double")),
|
||||
new DocValueField("field_2", Collections.singleton("float")),
|
||||
new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.emptyMap());
|
||||
new DocValueField("test_dep_var", Collections.singleton("keyword"))),
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap());
|
||||
DataFrameAnalysis analysis = new Regression("test_dep_var");
|
||||
|
||||
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
|
||||
|
@ -103,7 +107,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase {
|
|||
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
|
||||
new DocValueField("field_1", Collections.singleton("double")),
|
||||
new DocValueField("field_2", Collections.singleton("float")),
|
||||
new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.singletonMap("test_dep_var", 5L));
|
||||
new DocValueField("test_dep_var", Collections.singleton("keyword"))),
|
||||
Collections.emptyList(),
|
||||
Collections.singletonMap("test_dep_var", 5L));
|
||||
DataFrameAnalysis analysis = new Classification("test_dep_var");
|
||||
|
||||
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
|
||||
|
@ -126,7 +132,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase {
|
|||
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
|
||||
new DocValueField("field_1", Collections.singleton("double")),
|
||||
new DocValueField("field_2", Collections.singleton("float")),
|
||||
new DocValueField("test_dep_var", Collections.singleton("integer"))), Collections.singletonMap("test_dep_var", 8L));
|
||||
new DocValueField("test_dep_var", Collections.singleton("integer"))),
|
||||
Collections.emptyList(),
|
||||
Collections.singletonMap("test_dep_var", 8L));
|
||||
DataFrameAnalysis analysis = new Classification("test_dep_var");
|
||||
|
||||
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
|
||||
|
|
|
@ -105,7 +105,9 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
|
|||
OutlierDetectionTests.createRandom()).build();
|
||||
dataExtractor = mock(DataFrameDataExtractor.class);
|
||||
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
|
||||
when(dataExtractor.getExtractedFields()).thenReturn(new ExtractedFields(Collections.emptyList(), Collections.emptyMap()));
|
||||
when(dataExtractor.getExtractedFields()).thenReturn(new ExtractedFields(Collections.emptyList(),
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap()));
|
||||
dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
|
||||
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
|
||||
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));
|
||||
|
|
|
@ -314,6 +314,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
|||
trainedModelProvider,
|
||||
auditor,
|
||||
statsPersister,
|
||||
new ExtractedFields(fieldNames, Collections.emptyMap()));
|
||||
new ExtractedFields(fieldNames, Collections.emptyList(), Collections.emptyMap()));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -144,7 +144,7 @@ public class ChunkedTrainedModelPersisterTests extends ESTestCase {
|
|||
analyticsConfig,
|
||||
auditor,
|
||||
(unused)->{},
|
||||
new ExtractedFields(fieldNames, Collections.emptyMap()));
|
||||
new ExtractedFields(fieldNames, Collections.emptyList(), Collections.emptyMap()));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import java.util.Collections;
|
|||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.TreeSet;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
@ -31,8 +32,10 @@ public class ExtractedFieldsTests extends ESTestCase {
|
|||
ExtractedField scriptField2 = new ScriptField("scripted2");
|
||||
ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text"));
|
||||
ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text"));
|
||||
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
|
||||
docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), Collections.emptyMap());
|
||||
ExtractedFields extractedFields = new ExtractedFields(
|
||||
Arrays.asList(docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2),
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap());
|
||||
|
||||
assertThat(extractedFields.getAllFields().size(), equalTo(6));
|
||||
assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new),
|
||||
|
@ -53,8 +56,11 @@ public class ExtractedFieldsTests extends ESTestCase {
|
|||
when(fieldCapabilitiesResponse.getField("value")).thenReturn(valueCaps);
|
||||
when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps);
|
||||
|
||||
ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("time", "value", "airline", "airport"),
|
||||
new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse, Collections.emptyMap());
|
||||
ExtractedFields extractedFields = ExtractedFields.build(new TreeSet<>(Arrays.asList("time", "value", "airline", "airport")),
|
||||
new HashSet<>(Collections.singletonList("airport")),
|
||||
fieldCapabilitiesResponse,
|
||||
Collections.emptyMap(),
|
||||
Collections.emptyList());
|
||||
|
||||
assertThat(extractedFields.getDocValueFields().size(), equalTo(2));
|
||||
assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time"));
|
||||
|
@ -76,8 +82,8 @@ public class ExtractedFieldsTests extends ESTestCase {
|
|||
when(fieldCapabilitiesResponse.getField("airport")).thenReturn(text);
|
||||
when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword);
|
||||
|
||||
ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("airline.text", "airport.keyword"),
|
||||
Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap());
|
||||
ExtractedFields extractedFields = ExtractedFields.build(new TreeSet<>(Arrays.asList("airline.text", "airport.keyword")),
|
||||
Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap(), Collections.emptyList());
|
||||
|
||||
assertThat(extractedFields.getDocValueFields().size(), equalTo(1));
|
||||
assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword"));
|
||||
|
@ -112,14 +118,18 @@ public class ExtractedFieldsTests extends ESTestCase {
|
|||
assertThat(mapped.getName(), equalTo(aBool.getName()));
|
||||
assertThat(mapped.getMethod(), equalTo(aBool.getMethod()));
|
||||
assertThat(mapped.supportsFromSource(), is(false));
|
||||
expectThrows(UnsupportedOperationException.class, () -> mapped.newFromSource());
|
||||
expectThrows(UnsupportedOperationException.class, mapped::newFromSource);
|
||||
}
|
||||
|
||||
public void testBuildGivenFieldWithoutMappings() {
|
||||
FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class);
|
||||
|
||||
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build(
|
||||
Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap()));
|
||||
Collections.singleton("value"),
|
||||
Collections.emptySet(),
|
||||
fieldCapabilitiesResponse,
|
||||
Collections.emptyMap(),
|
||||
Collections.emptyList()));
|
||||
assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings"));
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.extractor;
|
||||
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.arrayContaining;
|
||||
import static org.hamcrest.Matchers.emptyArray;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasItems;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class ProcessedFieldTests extends ESTestCase {
|
||||
|
||||
public void testOneHotGetters() {
|
||||
String inputField = "foo";
|
||||
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
|
||||
assertThat(processedField.getInputFieldNames(), hasItems(inputField));
|
||||
assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column"));
|
||||
assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer")));
|
||||
assertThat(processedField.getOutputFieldType("baz_column"), equalTo(Collections.singleton("integer")));
|
||||
assertThat(processedField.getProcessorName(), equalTo(OneHotEncoding.NAME.getPreferredName()));
|
||||
}
|
||||
|
||||
public void testMissingExtractor() {
|
||||
String inputField = "foo";
|
||||
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
|
||||
assertThat(processedField.value(makeHit(), (s) -> null), emptyArray());
|
||||
}
|
||||
|
||||
public void testMissingInputValues() {
|
||||
String inputField = "foo";
|
||||
ExtractedField extractedField = makeExtractedField(new Object[0]);
|
||||
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
|
||||
assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue())));
|
||||
}
|
||||
|
||||
public void testProcessedField() {
|
||||
ProcessedField processedField = new ProcessedField(makePreProcessor("foo", "bar", "baz"));
|
||||
assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" })), arrayContaining(1, 0));
|
||||
assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "baz" })), arrayContaining(0, 1));
|
||||
}
|
||||
|
||||
private static PreProcessor makePreProcessor(String inputField, String... expectedExtractedValues) {
|
||||
return new OneHotEncoding(inputField,
|
||||
Arrays.stream(expectedExtractedValues).collect(Collectors.toMap(Function.identity(), (s) -> s + "_column")),
|
||||
true);
|
||||
}
|
||||
|
||||
private static ExtractedField makeExtractedField(Object[] value) {
|
||||
ExtractedField extractedField = mock(ExtractedField.class);
|
||||
when(extractedField.value(any())).thenReturn(value);
|
||||
return extractedField;
|
||||
}
|
||||
|
||||
private static SearchHit makeHit() {
|
||||
return new SearchHitBuilder(42).addField("a_keyword", "bar").build();
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue