`feature_processors` allow users to create custom features from individual document fields. These `feature_processors` are the same objects as a trained model's `pre_processors`: they are passed to the native process, which appends them to the `pre_processor` array in the resulting inference model. Closes https://github.com/elastic/elasticsearch/issues/59327
parent d1b60269f4
commit 8f302282f4
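To make the new option concrete, here is a minimal sketch (not part of this commit) of building a classification analysis with a user-supplied feature processor, using the constructor extended below. The field names and hot-map values are invented for illustration, and the ordering of the `null` defaults assumes the constructor signature shown in the Classification hunks.

// Illustrative sketch; compiles against the x-pack ML classes changed in this commit.
import java.util.Collections;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;

class FeatureProcessorExample {
    static Classification classificationWithFeatureProcessor() {
        // A user-defined one-hot encoding over a raw document field. The final
        // `true` marks it as custom, matching the default applied by the
        // new PreProcessorParseContext(true) used when parsing feature_processors.
        PreProcessor dayOfWeek = new OneHotEncoding(
            "day_of_week",
            Collections.singletonMap("monday", "day_of_week_monday"),
            true);
        return new Classification(
            "label",                              // dependent variable
            BoostedTreeParams.builder().build(),  // default hyperparameters
            null, null, null, null,               // optional settings left as defaults
            42L,                                  // randomize seed
            Collections.singletonList(dayOfWeek));
    }
}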
Classification.java:

@@ -15,10 +15,14 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.mapper.FieldAliasMapper;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

 import java.io.IOException;
 import java.util.Arrays;
@@ -46,6 +50,7 @@ public class Classification implements DataFrameAnalysis {
     public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
     public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
     public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
+    public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");

     private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";

@@ -59,6 +64,7 @@ public class Classification implements DataFrameAnalysis {
      */
     public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;

+    @SuppressWarnings("unchecked")
     private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
         ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
@@ -70,7 +76,8 @@ public class Classification implements DataFrameAnalysis {
                 (ClassAssignmentObjective) a[8],
                 (Integer) a[9],
                 (Double) a[10],
-                (Long) a[11]));
+                (Long) a[11],
+                (List<PreProcessor>) a[12]));
         parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
         BoostedTreeParams.declareFields(parser);
         parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
@@ -78,6 +85,12 @@ public class Classification implements DataFrameAnalysis {
         parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
         parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
         parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
+        parser.declareNamedObjects(optionalConstructorArg(),
+            (p, c, n) -> lenient ?
+                p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
+                p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
+            (classification) -> {/*TODO should we throw if this is not set?*/},
+            FEATURE_PROCESSORS);
         return parser;
     }

@@ -119,6 +132,7 @@ public class Classification implements DataFrameAnalysis {
     private final int numTopClasses;
     private final double trainingPercent;
     private final long randomizeSeed;
+    private final List<PreProcessor> featureProcessors;

     public Classification(String dependentVariable,
                           BoostedTreeParams boostedTreeParams,
@@ -126,7 +140,8 @@ public class Classification implements DataFrameAnalysis {
                           @Nullable ClassAssignmentObjective classAssignmentObjective,
                           @Nullable Integer numTopClasses,
                           @Nullable Double trainingPercent,
-                          @Nullable Long randomizeSeed) {
+                          @Nullable Long randomizeSeed,
+                          @Nullable List<PreProcessor> featureProcessors) {
         if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
             throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
         }
@@ -141,10 +156,11 @@ public class Classification implements DataFrameAnalysis {
         this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
         this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
         this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
+        this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
     }

     public Classification(String dependentVariable) {
-        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
+        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
     }

     public Classification(StreamInput in) throws IOException {
@@ -163,6 +179,11 @@ public class Classification implements DataFrameAnalysis {
         } else {
             randomizeSeed = Randomness.get().nextLong();
         }
+        if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
+            featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
+        } else {
+            featureProcessors = Collections.emptyList();
+        }
     }

     public String getDependentVariable() {
@@ -193,6 +214,10 @@ public class Classification implements DataFrameAnalysis {
         return randomizeSeed;
     }

+    public List<PreProcessor> getFeatureProcessors() {
+        return featureProcessors;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -211,6 +236,9 @@ public class Classification implements DataFrameAnalysis {
         if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
             out.writeOptionalLong(randomizeSeed);
         }
+        if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
+            out.writeNamedWriteableList(featureProcessors);
+        }
     }

     @Override
@@ -229,6 +257,9 @@ public class Classification implements DataFrameAnalysis {
         if (version.onOrAfter(Version.V_7_6_0)) {
             builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
         }
+        if (featureProcessors.isEmpty() == false) {
+            NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
+        }
         builder.endObject();
         return builder;
     }
@@ -249,6 +280,10 @@ public class Classification implements DataFrameAnalysis {
         }
         params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
         params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
+        if (featureProcessors.isEmpty() == false) {
+            params.put(FEATURE_PROCESSORS.getPreferredName(),
+                featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
+        }
         return params;
     }

@@ -390,6 +425,7 @@ public class Classification implements DataFrameAnalysis {
             && Objects.equals(predictionFieldName, that.predictionFieldName)
             && Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
             && Objects.equals(numTopClasses, that.numTopClasses)
+            && Objects.equals(featureProcessors, that.featureProcessors)
             && trainingPercent == that.trainingPercent
             && randomizeSeed == that.randomizeSeed;
     }
@@ -397,7 +433,7 @@ public class Classification implements DataFrameAnalysis {
     @Override
     public int hashCode() {
         return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
-            numTopClasses, trainingPercent, randomizeSeed);
+            numTopClasses, trainingPercent, randomizeSeed, featureProcessors);
     }

     public enum ClassAssignmentObjective {
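The `getParams` addition above hands the processors to the native process as a list of single-entry maps keyed by each processor's named-object name. A self-contained toy (plain strings standing in for processor definitions) showing the resulting shape:

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

class NamedObjectShapeDemo {
    public static void main(String[] args) {
        // Mirrors how Classification.getParams() wraps feature processors: a list
        // of one-entry maps keyed by processor name, which serializes to
        // [{"one_hot_encoding": {...}}, {"frequency_encoding": {...}}].
        List<String> processorNames = Arrays.asList("one_hot_encoding", "frequency_encoding");
        List<Map<String, String>> wrapped = processorNames.stream()
            .map(n -> Collections.singletonMap(n, "<processor definition>"))
            .collect(Collectors.toList());
        System.out.println(wrapped); // [{one_hot_encoding=<...>}, {frequency_encoding=<...>}]
    }
}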
Regression.java:

@@ -15,9 +15,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.mapper.NumberFieldMapper;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

 import java.io.IOException;
 import java.util.Arrays;
@@ -28,6 +32,7 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.stream.Collectors;

 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis {
     public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
     public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
     public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
+    public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");

     private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";

     private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
     private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);

+    @SuppressWarnings("unchecked")
     private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
         ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
@@ -59,7 +66,8 @@ public class Regression implements DataFrameAnalysis {
                 (Double) a[8],
                 (Long) a[9],
                 (LossFunction) a[10],
-                (Double) a[11]));
+                (Double) a[11],
+                (List<PreProcessor>) a[12]));
         parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
         BoostedTreeParams.declareFields(parser);
         parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
@@ -67,6 +75,12 @@ public class Regression implements DataFrameAnalysis {
         parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
         parser.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION);
         parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
+        parser.declareNamedObjects(optionalConstructorArg(),
+            (p, c, n) -> lenient ?
+                p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
+                p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
+            (regression) -> {/*TODO should we throw if this is not set?*/},
+            FEATURE_PROCESSORS);
         return parser;
     }

@@ -90,6 +104,7 @@ public class Regression implements DataFrameAnalysis {
     private final long randomizeSeed;
     private final LossFunction lossFunction;
     private final Double lossFunctionParameter;
+    private final List<PreProcessor> featureProcessors;

     public Regression(String dependentVariable,
                       BoostedTreeParams boostedTreeParams,
@@ -97,7 +112,8 @@ public class Regression implements DataFrameAnalysis {
                       @Nullable Double trainingPercent,
                       @Nullable Long randomizeSeed,
                       @Nullable LossFunction lossFunction,
-                      @Nullable Double lossFunctionParameter) {
+                      @Nullable Double lossFunctionParameter,
+                      @Nullable List<PreProcessor> featureProcessors) {
         if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
             throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
         }
@@ -112,10 +128,11 @@ public class Regression implements DataFrameAnalysis {
             throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName());
         }
         this.lossFunctionParameter = lossFunctionParameter;
+        this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
     }

     public Regression(String dependentVariable) {
-        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
+        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
     }

     public Regression(StreamInput in) throws IOException {
@@ -136,6 +153,11 @@ public class Regression implements DataFrameAnalysis {
             lossFunction = LossFunction.MSE;
             lossFunctionParameter = null;
         }
+        if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
+            featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
+        } else {
+            featureProcessors = Collections.emptyList();
+        }
     }

     public String getDependentVariable() {
@@ -166,6 +188,10 @@ public class Regression implements DataFrameAnalysis {
         return lossFunctionParameter;
     }

+    public List<PreProcessor> getFeatureProcessors() {
+        return featureProcessors;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -184,6 +210,9 @@ public class Regression implements DataFrameAnalysis {
             out.writeEnum(lossFunction);
             out.writeOptionalDouble(lossFunctionParameter);
         }
+        if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
+            out.writeNamedWriteableList(featureProcessors);
+        }
     }

     @Override
@@ -204,6 +233,9 @@ public class Regression implements DataFrameAnalysis {
         if (lossFunctionParameter != null) {
             builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
         }
+        if (featureProcessors.isEmpty() == false) {
+            NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
+        }
         builder.endObject();
         return builder;
     }
@@ -221,6 +253,10 @@ public class Regression implements DataFrameAnalysis {
         if (lossFunctionParameter != null) {
             params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
         }
+        if (featureProcessors.isEmpty() == false) {
+            params.put(FEATURE_PROCESSORS.getPreferredName(),
+                featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
+        }
         return params;
     }

@@ -304,13 +340,14 @@ public class Regression implements DataFrameAnalysis {
             && trainingPercent == that.trainingPercent
             && randomizeSeed == that.randomizeSeed
             && lossFunction == that.lossFunction
+            && Objects.equals(featureProcessors, that.featureProcessors)
             && Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
     }

     @Override
     public int hashCode() {
         return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
-            lossFunctionParameter);
+            lossFunctionParameter, featureProcessors);
     }

     public enum LossFunction {
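Both `equals` and `hashCode` gain the new field in Classification and Regression above; adding it to one but not the other would break the `Object` contract (equal objects must share a hash code, or hash-based collections misbehave). A minimal standalone analogue of the pattern:

import java.util.List;
import java.util.Objects;

final class Config {
    final long seed;
    final List<String> processors;

    Config(long seed, List<String> processors) {
        this.seed = seed;
        this.processors = processors;
    }

    @Override public boolean equals(Object o) {
        return o instanceof Config
            && seed == ((Config) o).seed
            && Objects.equals(processors, ((Config) o).processors);
    }

    @Override public int hashCode() {
        // must include every field that equals() compares
        return Objects.hash(seed, processors);
    }
}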
MlInferenceNamedXContentProvider.java:

@@ -57,23 +57,23 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {

         // PreProcessing Lenient
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME,
-            OneHotEncoding::fromXContentLenient));
+            (p, c) -> OneHotEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
-            TargetMeanEncoding::fromXContentLenient));
+            (p, c) -> TargetMeanEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME,
-            FrequencyEncoding::fromXContentLenient));
+            (p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
-            CustomWordEmbedding::fromXContentLenient));
+            (p, c) -> CustomWordEmbedding.fromXContentLenient(p)));

         // PreProcessing Strict
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
-            OneHotEncoding::fromXContentStrict));
+            (p, c) -> OneHotEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
-            TargetMeanEncoding::fromXContentStrict));
+            (p, c) -> TargetMeanEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME,
-            FrequencyEncoding::fromXContentStrict));
+            (p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
-            CustomWordEmbedding::fromXContentStrict));
+            (p, c) -> CustomWordEmbedding.fromXContentStrict(p)));

         // Model Lenient
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
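These entries switch from method references to lambdas purely so the parse context passed to `p.namedObject(Class, name, context)` reaches the per-processor parsers; a method reference to a one-argument factory silently drops it. (`CustomWordEmbedding` keeps ignoring the context, since it has no `custom` flag to default.) A toy illustration of the difference, with types simplified:

import java.util.function.BiFunction;

class ContextForwardingDemo {
    static String parse(String input) {                      // old style: context unaware
        return "parsed:" + input;
    }
    static String parse(String input, boolean customByDefault) {  // new style: context aware
        return "parsed:" + input + ",custom=" + customByDefault;
    }
    public static void main(String[] args) {
        // Before: the two-argument functional shape was satisfied by dropping the context.
        BiFunction<String, Boolean, String> before = (p, c) -> parse(p);
        // After: the lambda forwards the context, as the registry entries above now do.
        BiFunction<String, Boolean, String> after = (p, c) -> parse(p, c);
        System.out.println(before.apply("one_hot", true));  // parsed:one_hot
        System.out.println(after.apply("one_hot", true));   // parsed:one_hot,custom=true
    }
}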
TrainedModelDefinition.java:

@@ -56,8 +56,8 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
             TRAINED_MODEL);
         parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
             (p, c, n) -> ignoreUnknownFields ?
-                p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
-                p.namedObject(StrictlyParsedPreProcessor.class, n, null),
+                p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) :
+                p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
             (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
             PREPROCESSORS);
         return parser;
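The net effect of the two contexts: pre-processors embedded in a stored trained model definition parse with `PreProcessorParseContext.DEFAULT`, so `custom` still defaults to `false`, while the same processor types supplied as analytics `feature_processors` parse with `new PreProcessorParseContext(true)` (see the Classification and Regression parsers above) and default to `true`.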
CustomWordEmbedding.java:

@@ -50,15 +50,15 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
     public static final ParseField EMBEDDING_WEIGHTS = new ParseField("embedding_weights");
     public static final ParseField EMBEDDING_QUANT_SCALES = new ParseField("embedding_quant_scales");

-    public static final ConstructingObjectParser<CustomWordEmbedding, Void> STRICT_PARSER = createParser(false);
-    public static final ConstructingObjectParser<CustomWordEmbedding, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);

     @SuppressWarnings("unchecked")
-    private static ConstructingObjectParser<CustomWordEmbedding, Void> createParser(boolean lenient) {
-        ConstructingObjectParser<CustomWordEmbedding, Void> parser = new ConstructingObjectParser<>(
+    private static ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> createParser(boolean lenient) {
+        ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
             lenient,
-            a -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3]));
+            (a, c) -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3]));

         parser.declareField(ConstructingObjectParser.constructorArg(),
             (p, c) -> {
@@ -123,11 +123,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
     }

     public static CustomWordEmbedding fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+        return STRICT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT);
     }

     public static CustomWordEmbedding fromXContentLenient(XContentParser parser) {
-        return LENIENT_PARSER.apply(parser, null);
+        return LENIENT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT);
     }

     private static final int CONCAT_LAYER_SIZE = 80;
@@ -256,6 +256,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
         return false;
     }

+    @Override
+    public String getOutputFieldType(String outputField) {
+        return "dense_vector";
+    }
+
     @Override
     public long ramBytesUsed() {
         long size = SHALLOW_SIZE;
FrequencyEncoding.java:

@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.mapper.NumberFieldMapper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.io.IOException;
@@ -37,15 +38,18 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
     public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
     public static final ParseField CUSTOM = new ParseField("custom");

-    public static final ConstructingObjectParser<FrequencyEncoding, Void> STRICT_PARSER = createParser(false);
-    public static final ConstructingObjectParser<FrequencyEncoding, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);

     @SuppressWarnings("unchecked")
-    private static ConstructingObjectParser<FrequencyEncoding, Void> createParser(boolean lenient) {
-        ConstructingObjectParser<FrequencyEncoding, Void> parser = new ConstructingObjectParser<>(
+    private static ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> createParser(boolean lenient) {
+        ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
             lenient,
-            a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Boolean)a[3]));
+            (a, c) -> new FrequencyEncoding((String)a[0],
+                (String)a[1],
+                (Map<String, Double>)a[2],
+                a[3] == null ? c.isCustomByDefault() : (Boolean)a[3]));
         parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
         parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
         parser.declareObject(ConstructingObjectParser.constructorArg(),
@@ -55,12 +59,12 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
         return parser;
     }

-    public static FrequencyEncoding fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+    public static FrequencyEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
+        return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

-    public static FrequencyEncoding fromXContentLenient(XContentParser parser) {
-        return LENIENT_PARSER.apply(parser, null);
+    public static FrequencyEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
+        return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

     private final String field;
@@ -117,6 +121,11 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
         return custom;
     }

+    @Override
+    public String getOutputFieldType(String outputField) {
+        return NumberFieldMapper.NumberType.DOUBLE.typeName();
+    }
+
     @Override
     public String getName() {
         return NAME.getPreferredName();
OneHotEncoding.java:

@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.mapper.NumberFieldMapper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.io.IOException;
@@ -36,27 +37,29 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
     public static final ParseField HOT_MAP = new ParseField("hot_map");
     public static final ParseField CUSTOM = new ParseField("custom");

-    public static final ConstructingObjectParser<OneHotEncoding, Void> STRICT_PARSER = createParser(false);
-    public static final ConstructingObjectParser<OneHotEncoding, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);

     @SuppressWarnings("unchecked")
-    private static ConstructingObjectParser<OneHotEncoding, Void> createParser(boolean lenient) {
-        ConstructingObjectParser<OneHotEncoding, Void> parser = new ConstructingObjectParser<>(
+    private static ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> createParser(boolean lenient) {
+        ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
             lenient,
-            a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1], (Boolean)a[2]));
+            (a, c) -> new OneHotEncoding((String)a[0],
+                (Map<String, String>)a[1],
+                a[2] == null ? c.isCustomByDefault() : (Boolean)a[2]));
         parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
         parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
         parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
         return parser;
     }

-    public static OneHotEncoding fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+    public static OneHotEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
+        return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

-    public static OneHotEncoding fromXContentLenient(XContentParser parser) {
-        return LENIENT_PARSER.apply(parser, null);
+    public static OneHotEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
+        return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

     private final String field;
@@ -103,6 +106,11 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
         return custom;
     }

+    @Override
+    public String getOutputFieldType(String outputField) {
+        return NumberFieldMapper.NumberType.INTEGER.typeName();
+    }
+
     @Override
     public String getName() {
         return NAME.getPreferredName();
@@ -124,8 +132,9 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
         if (value == null) {
             return;
         }
+        final String stringValue = value.toString();
         hotMap.forEach((val, col) -> {
-            int encoding = value.toString().equals(val) ? 1 : 0;
+            int encoding = stringValue.equals(val) ? 1 : 0;
             fields.put(col, encoding);
         });
     }
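The `process` change above hoists `value.toString()` out of the per-entry lambda so it is computed once per document field rather than once per hot-map entry. A standalone rendering of the hoisted loop with invented example data:

import java.util.HashMap;
import java.util.Map;

class HoistDemo {
    public static void main(String[] args) {
        Map<String, String> hotMap = new HashMap<>();
        hotMap.put("monday", "col_monday");
        hotMap.put("tuesday", "col_tuesday");
        Object value = 3;                            // field values may be non-strings
        Map<String, Integer> fields = new HashMap<>();
        final String stringValue = value.toString(); // computed once, not per entry
        hotMap.forEach((val, col) -> fields.put(col, stringValue.equals(val) ? 1 : 0));
        System.out.println(fields);                  // both encoded columns are 0 for "3"
    }
}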
PreProcessor.java:

@@ -18,6 +18,18 @@ import java.util.Map;
  */
 public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable {

+    class PreProcessorParseContext {
+        public static final PreProcessorParseContext DEFAULT = new PreProcessorParseContext(false);
+
+        final boolean defaultIsCustomValue;
+
+        public PreProcessorParseContext(boolean defaultIsCustomValue) {
+            this.defaultIsCustomValue = defaultIsCustomValue;
+        }
+
+        public boolean isCustomByDefault() {
+            return defaultIsCustomValue;
+        }
+    }
+
     /**
      * The expected input fields
      */
@@ -48,4 +60,6 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
      */
     boolean isCustom();

+    String getOutputFieldType(String outputField);
+
 }
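The new `getOutputFieldType` hook lets each processor declare the mapping type of the fields it emits: `integer` for one-hot columns, `double` for frequency and target-mean encodings, and `dense_vector` for word embeddings, per the implementations in this commit. A hypothetical sketch of a caller building destination-index mappings from it; the `outputFields()` accessor is assumed (it is not shown in this hunk) and the helper is illustrative, not code from this commit:

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;

class OutputMappingSketch {
    // Hypothetical helper: assumes PreProcessor exposes outputFields(),
    // listing the names of the fields each processor creates.
    static Map<String, Object> fieldMappings(List<PreProcessor> processors) {
        Map<String, Object> mappings = new HashMap<>();
        for (PreProcessor processor : processors) {
            for (String outputField : processor.outputFields()) {
                mappings.put(outputField,
                    Collections.singletonMap("type", processor.getOutputFieldType(outputField)));
            }
        }
        return mappings;
    }
}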
TargetMeanEncoding.java:

@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.mapper.NumberFieldMapper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.io.IOException;
@@ -37,15 +38,19 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
     public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
     public static final ParseField CUSTOM = new ParseField("custom");

-    public static final ConstructingObjectParser<TargetMeanEncoding, Void> STRICT_PARSER = createParser(false);
-    public static final ConstructingObjectParser<TargetMeanEncoding, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);

     @SuppressWarnings("unchecked")
-    private static ConstructingObjectParser<TargetMeanEncoding, Void> createParser(boolean lenient) {
-        ConstructingObjectParser<TargetMeanEncoding, Void> parser = new ConstructingObjectParser<>(
+    private static ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> createParser(boolean lenient) {
+        ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
             lenient,
-            a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3], (Boolean)a[4]));
+            (a, c) -> new TargetMeanEncoding((String)a[0],
+                (String)a[1],
+                (Map<String, Double>)a[2],
+                (Double)a[3],
+                a[4] == null ? c.isCustomByDefault() : (Boolean)a[4]));
         parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
         parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
         parser.declareObject(ConstructingObjectParser.constructorArg(),
@@ -56,12 +61,12 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
         return parser;
     }

-    public static TargetMeanEncoding fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+    public static TargetMeanEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
+        return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

-    public static TargetMeanEncoding fromXContentLenient(XContentParser parser) {
-        return LENIENT_PARSER.apply(parser, null);
+    public static TargetMeanEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
+        return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

     private final String field;
@@ -128,6 +133,11 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
         return custom;
     }

+    @Override
+    public String getOutputFieldType(String outputField) {
+        return NumberFieldMapper.NumberType.DOUBLE.typeName();
+    }
+
     @Override
     public String getName() {
         return NAME.getPreferredName();
InferenceDefinition.java:

@@ -41,7 +41,7 @@ public class InferenceDefinition {
             (p, c, n) -> p.namedObject(InferenceModel.class, n, null),
             TRAINED_MODEL);
         PARSER.declareNamedObjects(InferenceDefinition.Builder::setPreProcessors,
-            (p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, null),
+            (p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
             (trainedModelDefBuilder) -> {},
             PREPROCESSORS);
     }
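`InferenceDefinition` parses already-persisted model definitions, so it now passes `PreProcessorParseContext.DEFAULT` (custom defaults to `false`) in place of the previous `null` context, which presumably matches how a missing `custom` flag was treated before this change.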
ReservedFieldNames.java:

@@ -327,12 +327,14 @@ public final class ReservedFieldNames {
         Regression.LOSS_FUNCTION_PARAMETER.getPreferredName(),
         Regression.PREDICTION_FIELD_NAME.getPreferredName(),
         Regression.TRAINING_PERCENT.getPreferredName(),
+        Regression.FEATURE_PROCESSORS.getPreferredName(),
         Classification.NAME.getPreferredName(),
         Classification.DEPENDENT_VARIABLE.getPreferredName(),
         Classification.PREDICTION_FIELD_NAME.getPreferredName(),
         Classification.CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(),
         Classification.NUM_TOP_CLASSES.getPreferredName(),
         Classification.TRAINING_PERCENT.getPreferredName(),
+        Classification.FEATURE_PROCESSORS.getPreferredName(),
         BoostedTreeParams.LAMBDA.getPreferredName(),
         BoostedTreeParams.GAMMA.getPreferredName(),
         BoostedTreeParams.ETA.getPreferredName(),
ML config index mappings (JSON):

@@ -34,6 +34,9 @@
       "feature_bag_fraction" : {
         "type" : "double"
       },
+      "feature_processors": {
+        "enabled": false
+      },
       "gamma" : {
         "type" : "double"
       },
@@ -84,6 +87,9 @@
       "feature_bag_fraction" : {
         "type" : "double"
       },
+      "feature_processors": {
+        "enabled": false
+      },
       "gamma" : {
         "type" : "double"
       },
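In Elasticsearch mappings, `"enabled": false` on an object field keeps the `feature_processors` definitions in `_source` without parsing or indexing their contents, so arbitrary user-supplied processor configurations cannot bloat the config index mappings.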
GetDataFrameAnalyticsActionResponseTests.java:

@@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Respon
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;

 import java.util.ArrayList;
 import java.util.Collections;
@@ -28,6 +29,7 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractWireSerial
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
         namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
         return new NamedWriteableRegistry(namedWriteables);
     }

@@ -36,6 +38,7 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractWireSerial
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
         namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
         return new NamedXContentRegistry(namedXContent);
     }
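These fixture changes follow from the config change: a `DataFrameAnalyticsConfig` can now embed named `PreProcessor` objects, so any test that round-trips a config over the wire or through XContent must register the inference named-writeable and named-xcontent entries, or deserialization of a config containing `feature_processors` fails with an unknown named object/writeable error. The same registration is repeated in each serialization test below.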
PutDataFrameAnalyticsActionRequestTests.java:

@@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.junit.Before;

 import java.util.ArrayList;
@@ -44,6 +45,7 @@ public class PutDataFrameAnalyticsActionRequestTests extends AbstractSerializing
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
         namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
         return new NamedWriteableRegistry(namedWriteables);
     }

@@ -52,6 +54,7 @@ public class PutDataFrameAnalyticsActionRequestTests extends AbstractSerializing
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
         namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
         return new NamedXContentRegistry(namedXContent);
     }
PutDataFrameAnalyticsActionResponseTests.java:

@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Response;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;

 import java.util.ArrayList;
 import java.util.Collections;
@@ -25,6 +26,7 @@ public class PutDataFrameAnalyticsActionResponseTests extends AbstractWireSerial
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
         namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
         return new NamedWriteableRegistry(namedWriteables);
     }
@@ -41,6 +41,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 import org.junit.Before;

@@ -78,6 +79,7 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
         namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
         return new NamedWriteableRegistry(namedWriteables);
     }

@@ -86,6 +88,7 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
         namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
         return new NamedXContentRegistry(namedXContent);
     }

@@ -144,14 +147,16 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
                 bwcRegression.getTrainingPercent(),
                 42L,
                 bwcRegression.getLossFunction(),
-                bwcRegression.getLossFunctionParameter());
+                bwcRegression.getLossFunctionParameter(),
+                bwcRegression.getFeatureProcessors());
             testAnalysis = new Regression(testRegression.getDependentVariable(),
                 testRegression.getBoostedTreeParams(),
                 testRegression.getPredictionFieldName(),
                 testRegression.getTrainingPercent(),
                 42L,
                 testRegression.getLossFunction(),
-                testRegression.getLossFunctionParameter());
+                testRegression.getLossFunctionParameter(),
+                bwcRegression.getFeatureProcessors());
         } else {
             Classification testClassification = (Classification)testInstance.getAnalysis();
             Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis();

@@ -161,14 +166,16 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
                 bwcClassification.getClassAssignmentObjective(),
                 bwcClassification.getNumTopClasses(),
                 bwcClassification.getTrainingPercent(),
-                42L);
+                42L,
+                bwcClassification.getFeatureProcessors());
             testAnalysis = new Classification(testClassification.getDependentVariable(),
                 testClassification.getBoostedTreeParams(),
                 testClassification.getPredictionFieldName(),
                 testClassification.getClassAssignmentObjective(),
                 testClassification.getNumTopClasses(),
                 testClassification.getTrainingPercent(),
-                42L);
+                42L,
+                testClassification.getFeatureProcessors());
         }
         super.assertOnBWCObject(new DataFrameAnalyticsConfig.Builder(bwcSerializedObject)
             .setAnalysis(bwcAnalysis)
@@ -8,25 +8,41 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.common.xcontent.json.JsonXContent;
 import org.elasticsearch.index.mapper.BooleanFieldMapper;
 import org.elasticsearch.index.mapper.KeywordFieldMapper;
 import org.elasticsearch.index.mapper.NumberFieldMapper;
+import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.containsString;
@@ -55,6 +71,21 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         return createRandom();
     }
 
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
+        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(entries);
+    }
+
     public static Classification createRandom() {
         String dependentVariableName = randomAlphaOfLength(10);
         BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
@@ -65,7 +96,14 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
         Long randomizeSeed = randomBoolean() ? null : randomLong();
         return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
-            numTopClasses, trainingPercent, randomizeSeed);
+            numTopClasses, trainingPercent, randomizeSeed,
+            randomBoolean() ?
+                null :
+                Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(true),
+                    OneHotEncodingTests.createRandom(true),
+                    TargetMeanEncodingTests.createRandom(true)))
+                    .limit(randomIntBetween(0, 5))
+                    .collect(Collectors.toList()));
     }
 
     public static Classification mutateForVersion(Classification instance, Version version) {
@@ -75,7 +113,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
             version.onOrAfter(Version.V_7_7_0) ? instance.getClassAssignmentObjective() : null,
             instance.getNumTopClasses(),
             instance.getTrainingPercent(),
-            instance.getRandomizeSeed());
+            instance.getRandomizeSeed(),
+            version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
     }
 
     @Override
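The V_7_10_0 gate above is the backwards-compatibility contract for the new field: when a config is sent to a node older than 7.10, the feature processors are dropped rather than serialized. A minimal sketch of that expectation (helper class and method names are ours, not the commit's):

import java.util.Collections;
import java.util.List;

import org.elasticsearch.Version;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;

final class BwcExpectations {
    // Pre-7.10 nodes do not understand feature_processors, so the BWC-mutated
    // instance carries an empty processor list instead of the original one.
    static List<PreProcessor> expectedFeatureProcessors(Version version, List<PreProcessor> processors) {
        return version.onOrAfter(Version.V_7_10_0) ? processors : Collections.emptyList();
    }
}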
@@ -91,14 +130,16 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
             bwcSerializedObject.getClassAssignmentObjective(),
             bwcSerializedObject.getNumTopClasses(),
             bwcSerializedObject.getTrainingPercent(),
-            42L);
+            42L,
+            bwcSerializedObject.getFeatureProcessors());
         Classification newInstance = new Classification(testInstance.getDependentVariable(),
             testInstance.getBoostedTreeParams(),
             testInstance.getPredictionFieldName(),
             testInstance.getClassAssignmentObjective(),
             testInstance.getNumTopClasses(),
             testInstance.getTrainingPercent(),
-            42L);
+            42L,
+            testInstance.getFeatureProcessors());
         super.assertOnBWCObject(newBwc, newInstance, version);
     }
@@ -107,87 +148,138 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         return Classification::new;
     }
 
+    public void testDeserialization() throws IOException {
+        String toDeserialize = "{\n" +
+            " \"dependent_variable\": \"FlightDelayMin\",\n" +
+            " \"feature_processors\": [\n" +
+            " {\n" +
+            " \"one_hot_encoding\": {\n" +
+            " \"field\": \"OriginWeather\",\n" +
+            " \"hot_map\": {\n" +
+            " \"sunny_col\": \"Sunny\",\n" +
+            " \"clear_col\": \"Clear\",\n" +
+            " \"rainy_col\": \"Rain\"\n" +
+            " }\n" +
+            " }\n" +
+            " },\n" +
+            " {\n" +
+            " \"one_hot_encoding\": {\n" +
+            " \"field\": \"DestWeather\",\n" +
+            " \"hot_map\": {\n" +
+            " \"dest_sunny_col\": \"Sunny\",\n" +
+            " \"dest_clear_col\": \"Clear\",\n" +
+            " \"dest_rainy_col\": \"Rain\"\n" +
+            " }\n" +
+            " }\n" +
+            " },\n" +
+            " {\n" +
+            " \"frequency_encoding\": {\n" +
+            " \"field\": \"OriginWeather\",\n" +
+            " \"feature_name\": \"mean\",\n" +
+            " \"frequency_map\": {\n" +
+            " \"Sunny\": 0.8,\n" +
+            " \"Rain\": 0.2\n" +
+            " }\n" +
+            " }\n" +
+            " }\n" +
+            " ]\n" +
+            " }" +
+            "";
+
+        try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+            new BytesArray(toDeserialize),
+            XContentType.JSON)) {
+            Classification parsed = Classification.fromXContent(parser, false);
+            assertThat(parsed.getDependentVariable(), equalTo("FlightDelayMin"));
+            for (PreProcessor preProcessor : parsed.getFeatureProcessors()) {
+                assertThat(preProcessor.isCustom(), is(true));
+            }
+        }
+    }
+
     public void testConstructor_GivenTrainingPercentIsLessThanOne() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong()));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong(), null));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
 
     public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong()));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
 
     public void testConstructor_GivenNumTopClassesIsLessThanZero() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong()));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null));
 
         assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
     }
 
     public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong()));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
 
         assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
     }
 
     public void testGetPredictionFieldName() {
-        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
+        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
         assertThat(classification.getPredictionFieldName(), equalTo("result"));
 
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null);
         assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
     }
 
     public void testClassAssignmentObjective() {
         Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
-            Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong());
+            Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null);
         assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
 
         classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
-            Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong());
+            Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null);
         assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
 
         // class_assignment_objective == null, default applied
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
         assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
     }
 
     public void testGetNumTopClasses() {
-        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
+        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
         assertThat(classification.getNumTopClasses(), equalTo(7));
 
         // Boundary condition: num_top_classes == 0
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
         assertThat(classification.getNumTopClasses(), equalTo(0));
 
         // Boundary condition: num_top_classes == 1000
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null);
         assertThat(classification.getNumTopClasses(), equalTo(1000));
 
         // num_top_classes == null, default applied
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null);
         assertThat(classification.getNumTopClasses(), equalTo(2));
     }
 
     public void testGetTrainingPercent() {
-        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
+        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
         assertThat(classification.getTrainingPercent(), equalTo(50.0));
 
         // Boundary condition: training_percent == 1.0
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null);
         assertThat(classification.getTrainingPercent(), equalTo(1.0));
 
         // Boundary condition: training_percent == 100.0
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null);
         assertThat(classification.getTrainingPercent(), equalTo(100.0));
 
         // training_percent == null, default applied
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null);
         assertThat(classification.getTrainingPercent(), equalTo(100.0));
     }
@@ -233,6 +325,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
                 null,
                 null,
                 50.0,
+                null,
                 null).getParams(fieldInfo),
             equalTo(
                 org.elasticsearch.common.collect.Map.of(
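Taken together, the hunks above extend the Classification constructor with a trailing feature-processor list. A hedged usage sketch, mirroring the argument order exercised by these tests (the values are illustrative, not from the commit):

import java.util.Collections;

import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;

// Nulls fall back to the defaults the getter tests assert:
// num_top_classes -> 2, training_percent -> 100.0,
// class_assignment_objective -> maximize_minimum_recall.
Classification classification = new Classification(
    "FlightDelayMin",                        // dependent_variable
    BoostedTreeParams.builder().build(),     // boosted tree hyperparameters
    null,                                    // prediction_field_name
    null,                                    // class_assignment_objective
    null,                                    // num_top_classes
    null,                                    // training_percent
    42L,                                     // randomize_seed
    Collections.singletonList(               // new: custom feature processors
        new OneHotEncoding("OriginWeather",
            Collections.singletonMap("Sunny", "sunny_col"), true)));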
@@ -8,18 +8,35 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.common.xcontent.json.JsonXContent;
+import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
@@ -45,6 +62,21 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
         return createRandom();
     }
 
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
+        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(entries);
+    }
+
     public static Regression createRandom() {
         return createRandom(BoostedTreeParamsTests.createRandom());
     }
@@ -57,7 +89,14 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
         Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
         Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
         return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
-            lossFunctionParameter);
+            lossFunctionParameter,
+            randomBoolean() ?
+                null :
+                Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(true),
+                    OneHotEncodingTests.createRandom(true),
+                    TargetMeanEncodingTests.createRandom(true)))
+                    .limit(randomIntBetween(0, 5))
+                    .collect(Collectors.toList()));
     }
 
     public static Regression mutateForVersion(Regression instance, Version version) {
@@ -67,7 +106,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
             instance.getTrainingPercent(),
             instance.getRandomizeSeed(),
             version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunction() : null,
-            version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunctionParameter() : null);
+            version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunctionParameter() : null,
+            version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
     }
 
     @Override
@@ -83,14 +123,16 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
             bwcSerializedObject.getTrainingPercent(),
             42L,
             bwcSerializedObject.getLossFunction(),
-            bwcSerializedObject.getLossFunctionParameter());
+            bwcSerializedObject.getLossFunctionParameter(),
+            bwcSerializedObject.getFeatureProcessors());
         Regression newInstance = new Regression(testInstance.getDependentVariable(),
             testInstance.getBoostedTreeParams(),
             testInstance.getPredictionFieldName(),
             testInstance.getTrainingPercent(),
             42L,
             testInstance.getLossFunction(),
-            testInstance.getLossFunctionParameter());
+            testInstance.getLossFunctionParameter(),
+            testInstance.getFeatureProcessors());
         super.assertOnBWCObject(newBwc, newInstance, version);
     }
@@ -104,56 +146,122 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
         return Regression::new;
     }
 
+    public void testDeserialization() throws IOException {
+        String toDeserialize = "{\n" +
+            " \"dependent_variable\": \"FlightDelayMin\",\n" +
+            " \"feature_processors\": [\n" +
+            " {\n" +
+            " \"one_hot_encoding\": {\n" +
+            " \"field\": \"OriginWeather\",\n" +
+            " \"hot_map\": {\n" +
+            " \"sunny_col\": \"Sunny\",\n" +
+            " \"clear_col\": \"Clear\",\n" +
+            " \"rainy_col\": \"Rain\"\n" +
+            " }\n" +
+            " }\n" +
+            " },\n" +
+            " {\n" +
+            " \"one_hot_encoding\": {\n" +
+            " \"field\": \"DestWeather\",\n" +
+            " \"hot_map\": {\n" +
+            " \"dest_sunny_col\": \"Sunny\",\n" +
+            " \"dest_clear_col\": \"Clear\",\n" +
+            " \"dest_rainy_col\": \"Rain\"\n" +
+            " }\n" +
+            " }\n" +
+            " },\n" +
+            " {\n" +
+            " \"frequency_encoding\": {\n" +
+            " \"field\": \"OriginWeather\",\n" +
+            " \"feature_name\": \"mean\",\n" +
+            " \"frequency_map\": {\n" +
+            " \"Sunny\": 0.8,\n" +
+            " \"Rain\": 0.2\n" +
+            " }\n" +
+            " }\n" +
+            " }\n" +
+            " ]\n" +
+            " }" +
+            "";
+
+        try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+            new BytesArray(toDeserialize),
+            XContentType.JSON)) {
+            Regression parsed = Regression.fromXContent(parser, false);
+            assertThat(parsed.getDependentVariable(), equalTo("FlightDelayMin"));
+            for (PreProcessor preProcessor : parsed.getFeatureProcessors()) {
+                assertThat(preProcessor.isCustom(), is(true));
+            }
+        }
+    }
+
     public void testConstructor_GivenTrainingPercentIsLessThanOne() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null));
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null, null));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
 
     public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null));
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
 
     public void testConstructor_GivenLossFunctionParameterIsZero() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0));
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0, null));
 
         assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
     }
 
     public void testConstructor_GivenLossFunctionParameterIsNegative() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0));
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0, null));
 
         assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
     }
 
     public void testGetPredictionFieldName() {
-        Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0);
+        Regression regression = new Regression(
+            "foo",
+            BOOSTED_TREE_PARAMS,
+            "result",
+            50.0,
+            randomLong(),
+            Regression.LossFunction.MSE,
+            1.0,
+            null);
         assertThat(regression.getPredictionFieldName(), equalTo("result"));
 
-        regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null);
+        regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null, null);
         assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
     }
 
     public void testGetTrainingPercent() {
-        Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0);
+        Regression regression = new Regression("foo",
+            BOOSTED_TREE_PARAMS,
+            "result",
+            50.0,
+            randomLong(),
+            Regression.LossFunction.MSE,
+            1.0,
+            null);
         assertThat(regression.getTrainingPercent(), equalTo(50.0));
 
         // Boundary condition: training_percent == 1.0
-        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null);
+        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null, null);
         assertThat(regression.getTrainingPercent(), equalTo(1.0));
 
         // Boundary condition: training_percent == 100.0
-        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null);
+        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null, null);
         assertThat(regression.getTrainingPercent(), equalTo(100.0));
 
         // training_percent == null, default applied
-        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null);
+        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null, null);
         assertThat(regression.getTrainingPercent(), equalTo(100.0));
     }
@@ -165,6 +273,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
             100.0,
             0L,
             Regression.LossFunction.MSE,
+            null,
             null);
 
         Map<String, Object> params = regression.getParams(null);
@@ -182,7 +291,9 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
 
         Map<String, Object> params = regression.getParams(null);
 
-        int expectedParamsCount = 4 + (regression.getLossFunctionParameter() == null ? 0 : 1);
+        int expectedParamsCount = 4
+            + (regression.getLossFunctionParameter() == null ? 0 : 1)
+            + (regression.getFeatureProcessors().isEmpty() ? 0 : 1);
         assertThat(params.size(), equalTo(expectedParamsCount));
         assertThat(params.get("dependent_variable"), equalTo(regression.getDependentVariable()));
         assertThat(params.get("prediction_field_name"), equalTo(regression.getPredictionFieldName()));
@@ -24,7 +24,9 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
 
     @Override
     protected FrequencyEncoding doParseInstance(XContentParser parser) throws IOException {
-        return lenient ? FrequencyEncoding.fromXContentLenient(parser) : FrequencyEncoding.fromXContentStrict(parser);
+        return lenient ?
+            FrequencyEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
+            FrequencyEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
     }
 
     @Override

@@ -33,6 +35,10 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
     }
 
     public static FrequencyEncoding createRandom() {
+        return createRandom(randomBoolean() ? null : randomBoolean());
+    }
+
+    public static FrequencyEncoding createRandom(Boolean isCustom) {
         int valuesSize = randomIntBetween(1, 10);
         Map<String, Double> valueMap = new HashMap<>();
         for (int i = 0; i < valuesSize; i++) {

@@ -41,7 +47,7 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
         return new FrequencyEncoding(randomAlphaOfLength(10),
             randomAlphaOfLength(10),
             valueMap,
-            randomBoolean() ? null : randomBoolean());
+            isCustom);
     }
 
     @Override
@@ -24,7 +24,9 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
 
     @Override
     protected OneHotEncoding doParseInstance(XContentParser parser) throws IOException {
-        return lenient ? OneHotEncoding.fromXContentLenient(parser) : OneHotEncoding.fromXContentStrict(parser);
+        return lenient ?
+            OneHotEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
+            OneHotEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
     }
 
     @Override

@@ -33,6 +35,10 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
     }
 
     public static OneHotEncoding createRandom() {
+        return createRandom(randomBoolean() ? randomBoolean() : null);
+    }
+
+    public static OneHotEncoding createRandom(Boolean isCustom) {
         int valuesSize = randomIntBetween(1, 10);
         Map<String, String> valueMap = new HashMap<>();
         for (int i = 0; i < valuesSize; i++) {

@@ -40,7 +46,7 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
         }
         return new OneHotEncoding(randomAlphaOfLength(10),
             valueMap,
-            randomBoolean() ? randomBoolean() : null);
+            isCustom);
     }
 
     @Override
@@ -24,7 +24,9 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
 
     @Override
     protected TargetMeanEncoding doParseInstance(XContentParser parser) throws IOException {
-        return lenient ? TargetMeanEncoding.fromXContentLenient(parser) : TargetMeanEncoding.fromXContentStrict(parser);
+        return lenient ?
+            TargetMeanEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
+            TargetMeanEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
     }
 
     @Override

@@ -32,7 +34,12 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
         return createRandom();
     }
 
     public static TargetMeanEncoding createRandom() {
+        return createRandom(randomBoolean() ? randomBoolean() : null);
+    }
+
+    public static TargetMeanEncoding createRandom(Boolean isCustom) {
         int valuesSize = randomIntBetween(1, 10);
         Map<String, Double> valueMap = new HashMap<>();
         for (int i = 0; i < valuesSize; i++) {

@@ -42,7 +49,7 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
             randomAlphaOfLength(10),
             valueMap,
             randomDoubleBetween(0.0, 1.0, false),
-            randomBoolean() ? randomBoolean() : null);
+            isCustom);
     }
 
     @Override
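The createRandom(Boolean isCustom) overloads above thread the custom flag through every randomly generated encoding; user-supplied feature processors are the custom == true case, which the testDeserialization assertions verify via isCustom(). A minimal sketch of building one of these processors directly (the values are illustrative):

import java.util.Collections;
import java.util.Map;

import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;

// The hot_map maps a field value to an output column name, as in the
// integration test below; the trailing Boolean marks the processor as
// user-defined ("custom") rather than produced by model training.
Map<String, String> hotMap = Collections.singletonMap("Sunny", "sunny_col");
OneHotEncoding oneHot = new OneHotEncoding("OriginWeather", hotMap, true);
assert oneHot.isCustom();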
@@ -21,22 +21,30 @@ import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
 import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigUpdate;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.junit.After;
 import org.junit.Before;
@@ -108,6 +116,15 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             .get();
     }
 
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
+        List<NamedXContentRegistry.Entry> entries = new ArrayList<>(searchModule.getNamedXContents());
+        entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
+        return new NamedXContentRegistry(entries);
+    }
+
     public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
         initialize("classification_single_numeric_feature_and_mixed_data_set");
         String predictedClassField = KEYWORD_FIELD + "_prediction";
@@ -121,6 +138,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             null,
             null,
             null,
+            null,
             null));
         putAnalytics(config);
@@ -176,6 +194,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             null,
             null,
             null,
+            null,
             null));
         putAnalytics(config);
@ -268,6 +287,76 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testWithCustomFeatureProcessors() throws Exception {
|
||||||
|
initialize("classification_with_custom_feature_processors");
|
||||||
|
String predictedClassField = KEYWORD_FIELD + "_prediction";
|
||||||
|
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
|
||||||
|
|
||||||
|
DataFrameAnalyticsConfig config =
|
||||||
|
buildAnalytics(jobId, sourceIndex, destIndex, null,
|
||||||
|
new Classification(
|
||||||
|
KEYWORD_FIELD,
|
||||||
|
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
Arrays.asList(
|
||||||
|
new OneHotEncoding(TEXT_FIELD, Collections.singletonMap(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom"), true)
|
||||||
|
)));
|
||||||
|
putAnalytics(config);
|
||||||
|
|
||||||
|
assertIsStopped(jobId);
|
||||||
|
assertProgressIsZero(jobId);
|
||||||
|
|
||||||
|
startAnalytics(jobId);
|
||||||
|
waitUntilAnalyticsIsStopped(jobId);
|
||||||
|
|
||||||
|
client().admin().indices().refresh(new RefreshRequest(destIndex));
|
||||||
|
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||||
|
for (SearchHit hit : sourceData.getHits()) {
|
||||||
|
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||||
|
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
|
||||||
|
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
|
||||||
|
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
||||||
|
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
|
||||||
|
assertThat(importanceArray, hasSize(greaterThan(0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
assertProgressComplete(jobId);
|
||||||
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
|
assertModelStatePersisted(stateDocId());
|
||||||
|
assertInferenceModelPersisted(jobId);
|
||||||
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||||
|
assertThatAuditMessagesMatch(jobId,
|
||||||
|
"Created analytics with analysis type [classification]",
|
||||||
|
"Estimated memory usage for this analytics to be",
|
||||||
|
"Starting analytics on node",
|
||||||
|
"Started analytics",
|
||||||
|
expectedDestIndexAuditMessage(),
|
||||||
|
"Started reindexing to destination index [" + destIndex + "]",
|
||||||
|
"Finished reindexing to destination index [" + destIndex + "]",
|
||||||
|
"Started loading data",
|
||||||
|
"Started analyzing",
|
||||||
|
"Started writing results",
|
||||||
|
"Finished analysis");
|
||||||
|
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||||
|
|
||||||
|
GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE,
|
||||||
|
new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet();
|
||||||
|
assertThat(response.getResources().results().size(), equalTo(1));
|
||||||
|
TrainedModelConfig modelConfig = response.getResources().results().get(0);
|
||||||
|
modelConfig.ensureParsedDefinition(xContentRegistry());
|
||||||
|
assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0));
|
||||||
|
for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
|
||||||
|
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
|
||||||
|
assertThat(preProcessor.isCustom(), equalTo(i == 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId,
|
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId,
|
||||||
String dependentVariable,
|
String dependentVariable,
|
||||||
List<T> dependentVariableValues,
|
List<T> dependentVariableValues,
|
||||||
|
@ -283,7 +372,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
sourceIndex,
|
sourceIndex,
|
||||||
destIndex,
|
destIndex,
|
||||||
null,
|
null,
|
||||||
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null));
|
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null, null));
|
||||||
putAnalytics(config);
|
putAnalytics(config);
|
||||||
|
|
||||||
assertIsStopped(jobId);
|
assertIsStopped(jobId);
|
||||||
|
@ -352,7 +441,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "integer");
|
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "integer");
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception {
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() {
|
||||||
ElasticsearchStatusException e = expectThrows(
|
ElasticsearchStatusException e = expectThrows(
|
||||||
ElasticsearchStatusException.class,
|
ElasticsearchStatusException.class,
|
||||||
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||||
|
@ -360,7 +449,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];"));
|
assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsText() throws Exception {
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsText() {
|
||||||
ElasticsearchStatusException e = expectThrows(
|
ElasticsearchStatusException e = expectThrows(
|
||||||
ElasticsearchStatusException.class,
|
ElasticsearchStatusException.class,
|
||||||
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||||
|
@ -549,7 +638,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||||
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null));
|
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null));
|
||||||
putAnalytics(firstJob);
|
putAnalytics(firstJob);
|
||||||
|
|
||||||
String secondJobId = "classification_two_jobs_with_same_randomize_seed_2";
|
String secondJobId = "classification_two_jobs_with_same_randomize_seed_2";
|
||||||
|
@ -557,7 +646,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
|
long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
|
||||||
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
|
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
|
||||||
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed));
|
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null));
|
||||||
|
|
||||||
putAnalytics(secondJob);
|
putAnalytics(secondJob);
|
||||||
|
|
||||||
|
|
|
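
Note: the extra trailing constructor argument threaded through all of these test updates is the new feature_processors list. A minimal sketch of what a caller passes (the field names, hot_map key, and output column name below are illustrative, and the argument labels follow the positional usage in the tests above rather than confirmed parameter names):

    // Classification with one user-supplied (custom == true) one-hot processor.
    Classification analysis = new Classification(
        "keyword-field",                           // dependent variable
        BoostedTreeParams.builder().build(),       // boosted tree hyperparameters
        null,                                      // prediction field name
        null,                                      // class assignment objective
        null,                                      // num top classes
        null,                                      // training percent
        null,                                      // randomize seed
        Collections.singletonList(                 // feature_processors
            new OneHotEncoding("text-field",
                Collections.singletonMap("cat", "cat_column_custom"),
                true)));
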
ExplainDataFrameAnalyticsIT.java

@@ -104,6 +104,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
                 100.0,
                 null,
                 null,
+                null,
                 null))
             .buildForExplain();

@@ -122,6 +123,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
                 50.0,
                 null,
                 null,
+                null,
                 null))
             .buildForExplain();

@@ -149,6 +151,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
                 100.0,
                 null,
                 null,
+                null,
                 null))
             .buildForExplain();
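
Note: in the three hunks above, the added null is the same new trailing feature_processors argument reaching the analysis constructors these explain tests build; no other explain behavior changes here.
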
RegressionIT.java

@@ -14,24 +14,36 @@ import org.elasticsearch.action.get.GetResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
 import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.junit.After;

 import java.io.IOException;
 import java.time.Instant;
-import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;

@@ -65,6 +77,15 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         cleanUp();
     }

+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
+        List<NamedXContentRegistry.Entry> entries = new ArrayList<>(searchModule.getNamedXContents());
+        entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
+        return new NamedXContentRegistry(entries);
+    }
+
     @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/59413")
     public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
         initialize("regression_single_numeric_feature_and_mixed_data_set");

@@ -79,6 +100,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             null,
             null,
             null,
+            null,
             null)
         );
         putAnalytics(config);

@@ -192,7 +214,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             sourceIndex,
             destIndex,
             null,
-            new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null));
+            new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null, null));
         putAnalytics(config);

         assertIsStopped(jobId);

@@ -319,7 +341,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             .build();

         DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
-            new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null));
+            new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null, null));
         putAnalytics(firstJob);

         String secondJobId = "regression_two_jobs_with_same_randomize_seed_2";

@@ -327,7 +349,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {

         long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed();
         DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
-            new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null));
+            new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null, null));

         putAnalytics(secondJob);

@@ -388,7 +410,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             sourceIndex,
             destIndex,
             null,
-            new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null));
+            new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null, null));
         putAnalytics(config);

         assertIsStopped(jobId);

@@ -415,6 +437,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             null,
             null,
             null,
+            null,
             null)
         );
         putAnalytics(config);

@@ -511,6 +534,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             90.0,
             null,
             null,
+            null,
             null);
         DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
             .setId(jobId)

@@ -566,6 +590,73 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             "Finished analysis");
     }

+    public void testWithCustomFeatureProcessors() throws Exception {
+        initialize("regression_with_custom_feature_processors");
+        String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
+        indexData(sourceIndex, 300, 50);
+
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
+            new Regression(
+                DEPENDENT_VARIABLE_FIELD,
+                BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
+                null,
+                null,
+                null,
+                null,
+                null,
+                Arrays.asList(
+                    new OneHotEncoding(DISCRETE_NUMERICAL_FEATURE_FIELD,
+                        Collections.singletonMap(DISCRETE_NUMERICAL_FEATURE_VALUES.get(0).toString(), "tenner"), true)
+                ))
+            );
+        putAnalytics(config);
+
+        assertIsStopped(jobId);
+        assertProgressIsZero(jobId);
+
+        startAnalytics(jobId);
+        waitUntilAnalyticsIsStopped(jobId);
+
+        // for debugging
+        SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
+        for (SearchHit hit : sourceData.getHits()) {
+            Map<String, Object> destDoc = getDestDoc(config, hit);
+            Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
+
+            assertThat(resultsObject.containsKey(predictedClassField), is(true));
+            assertThat(resultsObject.containsKey("is_training"), is(true));
+            assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
+        }
+
+        assertProgressComplete(jobId);
+        assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
+        assertModelStatePersisted(stateDocId());
+        assertInferenceModelPersisted(jobId);
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
+        assertThatAuditMessagesMatch(jobId,
+            "Created analytics with analysis type [regression]",
+            "Estimated memory usage for this analytics to be",
+            "Starting analytics on node",
+            "Started analytics",
+            "Creating destination index [" + destIndex + "]",
+            "Started reindexing to destination index [" + destIndex + "]",
+            "Finished reindexing to destination index [" + destIndex + "]",
+            "Started loading data",
+            "Started analyzing",
+            "Started writing results",
+            "Finished analysis");
+        GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE,
+            new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet();
+        assertThat(response.getResources().results().size(), equalTo(1));
+        TrainedModelConfig modelConfig = response.getResources().results().get(0);
+        modelConfig.ensureParsedDefinition(xContentRegistry());
+        assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0));
+        for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
+            PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
+            assertThat(preProcessor.isCustom(), equalTo(i == 0));
+        }
+    }
+
     private void initialize(String jobId) {
         this.jobId = jobId;
         this.sourceIndex = jobId + "_source_index";
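
Note: for intuition on what the regression test's processor produces, here is a small sketch of applying the same one-hot encoding to a single document map (assuming PreProcessor.process mutates the given map; the value 10 and the column name "tenner" mirror the test):

    import java.util.Collections;
    import java.util.HashMap;
    import java.util.Map;
    import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;

    Map<String, Object> doc = new HashMap<>();
    doc.put("discrete-numerical-field", 10);
    OneHotEncoding encoding = new OneHotEncoding(
        "discrete-numerical-field",
        Collections.singletonMap("10", "tenner"),   // hot_map: value -> output column
        true);                                      // custom, user-supplied processor
    encoding.process(doc);
    // doc now also carries the encoded column, e.g. {"tenner": 1};
    // a document with any other value would get {"tenner": 0}.
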
ChunkedTrainedModelPersisterIT.java

@@ -71,7 +71,7 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
             analyticsConfig,
             new DataFrameAnalyticsAuditor(client(), "test-node"),
             (ex) -> { throw new ElasticsearchException(ex); },
-            new ExtractedFields(extractedFieldList, Collections.emptyMap())
+            new ExtractedFields(extractedFieldList, Collections.emptyList(), Collections.emptyMap())
         );

         //Accuracy for size is not tested here
DataFrameAnalyticsConfigProviderIT.java

@@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigUpdate;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;

@@ -171,9 +172,9 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
             blockingCall(
                 actionListener -> configProvider.put(initialConfig, emptyMap(), actionListener), configHolder, exceptionHolder);

+            assertNoException(exceptionHolder);
             assertThat(configHolder.get(), is(notNullValue()));
             assertThat(configHolder.get(), is(equalTo(initialConfig)));
-            assertThat(exceptionHolder.get(), is(nullValue()));
         }
         { // Update that changes description
             AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();

@@ -188,7 +189,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
                 actionListener -> configProvider.update(configUpdate, emptyMap(), ClusterState.EMPTY_STATE, actionListener),
                 updatedConfigHolder,
                 exceptionHolder);
+            assertNoException(exceptionHolder);
             assertThat(updatedConfigHolder.get(), is(notNullValue()));
             assertThat(
                 updatedConfigHolder.get(),

@@ -196,7 +197,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
                     new DataFrameAnalyticsConfig.Builder(initialConfig)
                         .setDescription("description-1")
                         .build())));
-            assertThat(exceptionHolder.get(), is(nullValue()));
         }
         { // Update that changes model memory limit
             AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();

@@ -212,6 +212,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
                 updatedConfigHolder,
                 exceptionHolder);

+            assertNoException(exceptionHolder);
             assertThat(updatedConfigHolder.get(), is(notNullValue()));
             assertThat(
                 updatedConfigHolder.get(),

@@ -220,7 +221,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
                         .setDescription("description-1")
                         .setModelMemoryLimit(new ByteSizeValue(1024))
                         .build())));
-            assertThat(exceptionHolder.get(), is(nullValue()));
         }
         { // Noop update
             AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();

@@ -233,6 +233,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
                 updatedConfigHolder,
                 exceptionHolder);

+            assertNoException(exceptionHolder);
             assertThat(updatedConfigHolder.get(), is(notNullValue()));
             assertThat(
                 updatedConfigHolder.get(),

@@ -241,7 +242,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
                         .setDescription("description-1")
                         .setModelMemoryLimit(new ByteSizeValue(1024))
                         .build())));
-            assertThat(exceptionHolder.get(), is(nullValue()));
         }
         { // Update that changes both description and model memory limit
             AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();

@@ -258,6 +258,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
                 updatedConfigHolder,
                 exceptionHolder);

+            assertNoException(exceptionHolder);
             assertThat(updatedConfigHolder.get(), is(notNullValue()));
             assertThat(
                 updatedConfigHolder.get(),

@@ -266,7 +267,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
                         .setDescription("description-2")
                         .setModelMemoryLimit(new ByteSizeValue(2048))
                         .build())));
-            assertThat(exceptionHolder.get(), is(nullValue()));
         }
         { // Update that applies security headers
             Map<String, String> securityHeaders = Collections.singletonMap("_xpack_security_authentication", "dummy");

@@ -281,6 +281,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
                 updatedConfigHolder,
                 exceptionHolder);

+            assertNoException(exceptionHolder);
             assertThat(updatedConfigHolder.get(), is(notNullValue()));
             assertThat(
                 updatedConfigHolder.get(),

@@ -290,7 +291,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
                         .setModelMemoryLimit(new ByteSizeValue(2048))
                         .setHeaders(securityHeaders)
                         .build())));
-            assertThat(exceptionHolder.get(), is(nullValue()));
         }
     }

@@ -371,6 +371,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
         namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
         return new NamedXContentRegistry(namedXContent);
     }
 }
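
Note: the recurring change in this file swaps the trailing assertThat(exceptionHolder.get(), is(nullValue())) for an up-front assertNoException(exceptionHolder), so a stored exception fails the test immediately and surfaces the real cause. The helper itself is not part of these hunks; a hypothetical sketch of its shape, assuming it lives in the shared test base class:

    // Hypothetical helper shape (not shown in this commit's hunks):
    protected static void assertNoException(AtomicReference<Exception> exceptionHolder) throws Exception {
        if (exceptionHolder.get() != null) {
            throw exceptionHolder.get();
        }
    }
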
TimeBasedExtractedFields.java

@@ -28,7 +28,9 @@ public class TimeBasedExtractedFields extends ExtractedFields {
     private final ExtractedField timeField;

     public TimeBasedExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) {
-        super(allFields, Collections.emptyMap());
+        super(allFields,
+            Collections.emptyList(),
+            Collections.emptyMap());
         if (!allFields.contains(timeField)) {
             throw new IllegalArgumentException("timeField should also be contained in allFields");
         }
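
Note: time-based extraction (the anomaly detection path) has no feature processors, so TimeBasedExtractedFields simply threads an empty processed-field list through the widened ExtractedFields super-constructor.
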
DataFrameDataExtractor.java

@@ -28,15 +28,18 @@ import org.elasticsearch.search.fetch.StoredFieldsContext;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
 import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
+import org.elasticsearch.xpack.ml.extractor.ProcessedField;

 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;

@@ -46,6 +49,7 @@ import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;

 /**
  * An implementation that extracts data from elasticsearch using search and scroll on a client.

@@ -67,10 +71,29 @@ public class DataFrameDataExtractor {
     private boolean hasNext;
     private boolean searchHasShardFailure;
     private final CachedSupplier<TrainTestSplitter> trainTestSplitter;
+    // These are fields that are sent directly to the analytics process
+    // They are not passed through a feature_processor
+    private final String[] organicFeatures;
+    // These are the output field names for the feature_processors
+    private final String[] processedFeatures;
+    private final Map<String, ExtractedField> extractedFieldsByName;

     DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) {
         this.client = Objects.requireNonNull(client);
         this.context = Objects.requireNonNull(context);
+        Set<String> processedFieldInputs = context.extractedFields.getProcessedFieldInputs();
+        this.organicFeatures = context.extractedFields.getAllFields()
+            .stream()
+            .map(ExtractedField::getName)
+            .filter(f -> processedFieldInputs.contains(f) == false)
+            .toArray(String[]::new);
+        this.processedFeatures = context.extractedFields.getProcessedFields()
+            .stream()
+            .map(ProcessedField::getOutputFieldNames)
+            .flatMap(List::stream)
+            .toArray(String[]::new);
+        this.extractedFieldsByName = new LinkedHashMap<>();
+        context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f));
         hasNext = true;
         searchHasShardFailure = false;
         this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create);
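
Note: the constructor now partitions the columns once up front: a field consumed by any feature_processor is no longer sent to the native process as-is; only the processor's output columns are. A standalone sketch of that partitioning using plain collections (names illustrative):

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;

    List<String> allFields = Arrays.asList("numeric-field", "text-field");
    Set<String> processedInputs = new HashSet<>(Collections.singletonList("text-field"));
    List<String> processorOutputs = Collections.singletonList("cat_column_custom");

    String[] organicFeatures = allFields.stream()
        .filter(f -> processedInputs.contains(f) == false)
        .toArray(String[]::new);                                           // ["numeric-field"]
    String[] processedFeatures = processorOutputs.toArray(new String[0]);  // ["cat_column_custom"]
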
@@ -188,26 +211,78 @@
         return rows;
     }

+    private String extractNonProcessedValues(SearchHit hit, String organicFeature) {
+        ExtractedField field = extractedFieldsByName.get(organicFeature);
+        Object[] values = field.value(hit);
+        if (values.length == 1 && isValidValue(values[0])) {
+            return Objects.toString(values[0]);
+        }
+        if (values.length == 0 && context.supportsRowsWithMissingValues) {
+            // if values is empty then it means it's a missing value
+            return NULL_VALUE;
+        }
+        // we are here if we have a missing value but the analysis does not support those
+        // or the value type is not supported (e.g. arrays, etc.)
+        return null;
+    }
+
+    private String[] extractProcessedValue(ProcessedField processedField, SearchHit hit) {
+        Object[] values = processedField.value(hit, extractedFieldsByName::get);
+        if (values.length == 0 && context.supportsRowsWithMissingValues == false) {
+            return null;
+        }
+        final String[] extractedValue = new String[processedField.getOutputFieldNames().size()];
+        for (int i = 0; i < processedField.getOutputFieldNames().size(); i++) {
+            extractedValue[i] = NULL_VALUE;
+        }
+        // if values is empty then it means it's a missing value
+        if (values.length == 0) {
+            return extractedValue;
+        }
+
+        if (values.length != processedField.getOutputFieldNames().size()) {
+            throw ExceptionsHelper.badRequestException(
+                "field_processor [{}] output size expected to be [{}], instead it was [{}]",
+                processedField.getProcessorName(),
+                processedField.getOutputFieldNames().size(),
+                values.length);
+        }
+
+        for (int i = 0; i < processedField.getOutputFieldNames().size(); ++i) {
+            Object value = values[i];
+            if (value == null && context.supportsRowsWithMissingValues) {
+                continue;
+            }
+            if (isValidValue(value) == false) {
+                // we are here if we have a missing value but the analysis does not support those
+                // or the value type is not supported (e.g. arrays, etc.)
+                return null;
+            }
+            extractedValue[i] = Objects.toString(value);
+        }
+        return extractedValue;
+    }
+
     private Row createRow(SearchHit hit) {
-        String[] extractedValues = new String[context.extractedFields.getAllFields().size()];
-        for (int i = 0; i < extractedValues.length; ++i) {
-            ExtractedField field = context.extractedFields.getAllFields().get(i);
-            Object[] values = field.value(hit);
-            if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) {
-                extractedValues[i] = Objects.toString(values[0]);
-            } else {
-                if (values.length == 0 && context.supportsRowsWithMissingValues) {
-                    // if values is empty then it means it's a missing value
-                    extractedValues[i] = NULL_VALUE;
-                } else {
-                    // we are here if we have a missing value but the analysis does not support those
-                    // or the value type is not supported (e.g. arrays, etc.)
-                    extractedValues = null;
-                    break;
-                }
+        String[] extractedValues = new String[organicFeatures.length + processedFeatures.length];
+        int i = 0;
+        for (String organicFeature : organicFeatures) {
+            String extractedValue = extractNonProcessedValues(hit, organicFeature);
+            if (extractedValue == null) {
+                return new Row(null, hit, true);
+            }
+            extractedValues[i++] = extractedValue;
+        }
+        for (ProcessedField processedField : context.extractedFields.getProcessedFields()) {
+            String[] processedValues = extractProcessedValue(processedField, hit);
+            if (processedValues == null) {
+                return new Row(null, hit, true);
+            }
+            for (String processedValue : processedValues) {
+                extractedValues[i++] = processedValue;
             }
         }
-        boolean isTraining = extractedValues == null ? false : trainTestSplitter.get().isTraining(extractedValues);
+        boolean isTraining = trainTestSplitter.get().isTraining(extractedValues);
         return new Row(extractedValues, hit, isTraining);
     }

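
Note: createRow now fills each row in two passes, organic values first and processor outputs second, short-circuiting to a skipped row on the first unextractable value. The two-pass fill in isolation (values illustrative):

    String[] organicValues = {"34"};                 // one organic feature value
    String[][] processorOutputs = {{"1", "0"}};      // one processor emitting two columns
    String[] row = new String[3];
    int i = 0;
    for (String v : organicValues) {
        row[i++] = v;
    }
    for (String[] outputs : processorOutputs) {
        for (String v : outputs) {
            row[i++] = v;
        }
    }
    // row == ["34", "1", "0"]
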
@@ -241,7 +316,7 @@
     }

     public List<String> getFieldNames() {
-        return context.extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList());
+        return Stream.concat(Arrays.stream(organicFeatures), Arrays.stream(processedFeatures)).collect(Collectors.toList());
     }

     public ExtractedFields getExtractedFields() {

@@ -253,12 +328,12 @@
         SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
         long rows = searchResponse.getHits().getTotalHits().value;
         LOGGER.debug("[{}] Data summary rows [{}]", context.jobId, rows);
-        return new DataSummary(rows, context.extractedFields.getAllFields().size());
+        return new DataSummary(rows, organicFeatures.length + processedFeatures.length);
     }

     public void collectDataSummaryAsync(ActionListener<DataSummary> dataSummaryActionListener) {
         SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
-        final int numberOfFields = context.extractedFields.getAllFields().size();
+        final int numberOfFields = organicFeatures.length + processedFeatures.length;

         ClientHelper.executeWithHeadersAsync(context.headers,
             ClientHelper.ML_ORIGIN,

@@ -298,7 +373,11 @@
     }

     public Set<String> getCategoricalFields(DataFrameAnalysis analysis) {
-        return ExtractedFieldsDetector.getCategoricalFields(context.extractedFields, analysis);
+        return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis);
+    }
+
+    private static boolean isValidValue(Object value) {
+        return value instanceof Number || value instanceof String;
     }

     public static class DataSummary {
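
Note: the column count and order must now agree everywhere; getFieldNames(), the DataSummary, and each Row all use the organic features followed by the processed output columns. A quick check of the concatenation (values illustrative):

    import java.util.Arrays;
    import java.util.List;
    import java.util.stream.Collectors;
    import java.util.stream.Stream;

    String[] organicFeatures = {"age"};
    String[] processedFeatures = {"tenner"};
    List<String> fieldNames = Stream
        .concat(Arrays.stream(organicFeatures), Arrays.stream(processedFeatures))
        .collect(Collectors.toList());   // [age, tenner]
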
ExtractedFieldsDetector.java

@@ -13,27 +13,33 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.regex.Regex;
+import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.mapper.BooleanFieldMapper;
 import org.elasticsearch.index.mapper.ObjectMapper;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types;
 import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.NameResolver;
 import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
+import org.elasticsearch.xpack.ml.extractor.ProcessedField;

 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;

@@ -60,7 +66,9 @@ public class ExtractedFieldsDetector {
     private final FieldCapabilitiesResponse fieldCapabilitiesResponse;
     private final Map<String, Long> cardinalitiesForFieldsWithConstraints;

-    ExtractedFieldsDetector(DataFrameAnalyticsConfig config, int docValueFieldsLimit, FieldCapabilitiesResponse fieldCapabilitiesResponse,
+    ExtractedFieldsDetector(DataFrameAnalyticsConfig config,
+                            int docValueFieldsLimit,
+                            FieldCapabilitiesResponse fieldCapabilitiesResponse,
                             Map<String, Long> cardinalitiesForFieldsWithConstraints) {
         this.config = Objects.requireNonNull(config);
         this.docValueFieldsLimit = docValueFieldsLimit;

@@ -69,23 +77,39 @@
     }

     public Tuple<ExtractedFields, List<FieldSelection>> detect() {
+        List<ProcessedField> processedFields = extractFeatureProcessors()
+            .stream()
+            .map(ProcessedField::new)
+            .collect(Collectors.toList());
         TreeSet<FieldSelection> fieldSelection = new TreeSet<>(Comparator.comparing(FieldSelection::getName));
-        Set<String> fields = getIncludedFields(fieldSelection);
+        Set<String> fields = getIncludedFields(fieldSelection,
+            processedFields.stream()
+                .map(ProcessedField::getInputFieldNames)
+                .flatMap(List::stream)
+                .collect(Collectors.toSet()));
         checkFieldsHaveCompatibleTypes(fields);
         checkRequiredFields(fields);
         checkFieldsWithCardinalityLimit();
-        ExtractedFields extractedFields = detectExtractedFields(fields, fieldSelection);
+        ExtractedFields extractedFields = detectExtractedFields(fields, fieldSelection, processedFields);
         addIncludedFields(extractedFields, fieldSelection);

+        checkOutputFeatureUniqueness(processedFields, fields);
+
         return Tuple.tuple(extractedFields, Collections.unmodifiableList(new ArrayList<>(fieldSelection)));
     }

-    private Set<String> getIncludedFields(Set<FieldSelection> fieldSelection) {
+    private Set<String> getIncludedFields(Set<FieldSelection> fieldSelection, Set<String> requiredFieldsForProcessors) {
         Set<String> fields = new TreeSet<>(fieldCapabilitiesResponse.get().keySet());
+        validateFieldsRequireForProcessors(requiredFieldsForProcessors);
         fields.removeAll(IGNORE_FIELDS);
         removeFieldsUnderResultsField(fields);
         removeObjects(fields);
         applySourceFiltering(fields);
+        if (fields.containsAll(requiredFieldsForProcessors) == false) {
+            throw ExceptionsHelper.badRequestException(
+                "fields {} required by field_processors are not included in source filtering.",
+                Sets.difference(requiredFieldsForProcessors, fields));
+        }
         FetchSourceContext analyzedFields = config.getAnalyzedFields();

         // If the user has not explicitly included fields we'll include all compatible fields

@@ -93,20 +117,63 @@
             removeFieldsWithIncompatibleTypes(fields, fieldSelection);
         }
         includeAndExcludeFields(fields, fieldSelection);
+        if (fields.containsAll(requiredFieldsForProcessors) == false) {
+            throw ExceptionsHelper.badRequestException(
+                "fields {} required by field_processors are not included in the analyzed_fields.",
+                Sets.difference(requiredFieldsForProcessors, fields));
+        }
+
         return fields;
     }

+    private void validateFieldsRequireForProcessors(Set<String> processorFields) {
+        Set<String> fieldsForProcessor = new HashSet<>(processorFields);
+        removeFieldsUnderResultsField(fieldsForProcessor);
+        if (fieldsForProcessor.size() < processorFields.size()) {
+            throw ExceptionsHelper.badRequestException("fields contained in results field [{}] cannot be used in a feature_processor",
+                config.getDest().getResultsField());
+        }
+        removeObjects(fieldsForProcessor);
+        if (fieldsForProcessor.size() < processorFields.size()) {
+            throw ExceptionsHelper.badRequestException("fields for feature_processors must not be objects");
+        }
+        fieldsForProcessor.removeAll(IGNORE_FIELDS);
+        if (fieldsForProcessor.size() < processorFields.size()) {
+            throw ExceptionsHelper.badRequestException("the following fields cannot be used in feature_processors {}", IGNORE_FIELDS);
+        }
+        List<String> fieldsMissingInMapping = processorFields.stream()
+            .filter(f -> fieldCapabilitiesResponse.get().containsKey(f) == false)
+            .collect(Collectors.toList());
+        if (fieldsMissingInMapping.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException(
+                "the fields {} were not found in the field capabilities of the source indices [{}]. "
+                    + "Fields must exist and be mapped to be used in feature_processors.",
+                fieldsMissingInMapping,
+                Strings.arrayToCommaDelimitedString(config.getSource().getIndex()));
+        }
+        List<String> processedRequiredFields = config.getAnalysis()
+            .getRequiredFields()
+            .stream()
+            .map(RequiredField::getName)
+            .filter(processorFields::contains)
+            .collect(Collectors.toList());
+        if (processedRequiredFields.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException(
+                "required analysis fields {} cannot be used in a feature_processor",
+                processedRequiredFields);
+        }
+    }
+
     private void removeFieldsUnderResultsField(Set<String> fields) {
-        String resultsField = config.getDest().getResultsField();
+        final String resultsFieldPrefix = config.getDest().getResultsField() + ".";
         Iterator<String> fieldsIterator = fields.iterator();
         while (fieldsIterator.hasNext()) {
             String field = fieldsIterator.next();
-            if (field.startsWith(resultsField + ".")) {
+            if (field.startsWith(resultsFieldPrefix)) {
                 fieldsIterator.remove();
             }
         }
-        fields.removeIf(field -> field.startsWith(resultsField + "."));
+        fields.removeIf(field -> field.startsWith(resultsFieldPrefix));
     }

     private void removeObjects(Set<String> fields) {
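
Note: getIncludedFields now checks twice that every processor input survived filtering, once after source filtering and once after analyzed_fields are applied, and the error message names exactly the missing fields. The set arithmetic behind that message, in plain Java (equivalent to the Sets.difference call above):

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.HashSet;
    import java.util.Set;

    Set<String> requiredByProcessors = new HashSet<>(Arrays.asList("text-field", "numeric-field"));
    Set<String> includedFields = new HashSet<>(Collections.singletonList("numeric-field"));
    if (includedFields.containsAll(requiredByProcessors) == false) {
        Set<String> missing = new HashSet<>(requiredByProcessors);
        missing.removeAll(includedFields);   // [text-field]
        // -> "fields [text-field] required by field_processors are not included ..."
    }
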
@@ -287,9 +354,23 @@ public class ExtractedFieldsDetector {
         }
     }
 
-    private ExtractedFields detectExtractedFields(Set<String> fields, Set<FieldSelection> fieldSelection) {
-        ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse,
-            cardinalitiesForFieldsWithConstraints);
+    private List<PreProcessor> extractFeatureProcessors() {
+        if (config.getAnalysis() instanceof Classification) {
+            return ((Classification)config.getAnalysis()).getFeatureProcessors();
+        } else if (config.getAnalysis() instanceof Regression) {
+            return ((Regression)config.getAnalysis()).getFeatureProcessors();
+        }
+        return Collections.emptyList();
+    }
+
+    private ExtractedFields detectExtractedFields(Set<String> fields,
+                                                  Set<FieldSelection> fieldSelection,
+                                                  List<ProcessedField> processedFields) {
+        ExtractedFields extractedFields = ExtractedFields.build(fields,
+            Collections.emptySet(),
+            fieldCapabilitiesResponse,
+            cardinalitiesForFieldsWithConstraints,
+            processedFields);
         boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
         extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection);
         if (preferSource) {
@@ -304,10 +385,15 @@ public class ExtractedFieldsDetector {
         return extractedFields;
     }
 
-    private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, boolean preferSource,
-                                                   Set<FieldSelection> fieldSelection) {
-        Set<String> requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName)
+    private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields,
+                                                   boolean preferSource,
+                                                   Set<FieldSelection> fieldSelection) {
+        Set<String> requiredFields = config.getAnalysis()
+            .getRequiredFields()
+            .stream()
+            .map(RequiredField::getName)
             .collect(Collectors.toSet());
+        Set<String> processorInputFields = extractedFields.getProcessedFieldInputs();
         Map<String, ExtractedField> nameOrParentToField = new LinkedHashMap<>();
         for (ExtractedField currentField : extractedFields.getAllFields()) {
             String nameOrParent = currentField.isMultiField() ? currentField.getParentField() : currentField.getName();
@@ -315,15 +401,37 @@ public class ExtractedFieldsDetector {
             if (existingField != null) {
                 ExtractedField parent = currentField.isMultiField() ? existingField : currentField;
                 ExtractedField multiField = currentField.isMultiField() ? currentField : existingField;
+                // If required fields contains parent or multifield and the processor input fields reference the other, that is an error
+                // we should not allow processing of data that is required.
+                if ((requiredFields.contains(parent.getName()) && processorInputFields.contains(multiField.getName()))
+                    || (requiredFields.contains(multiField.getName()) && processorInputFields.contains(parent.getName()))) {
+                    throw ExceptionsHelper.badRequestException(
+                        "feature_processors cannot be applied to required fields for analysis; multi-field [{}] parent [{}]",
+                        multiField.getName(),
+                        parent.getName());
+                }
+                // If processor input fields have BOTH, we need to keep both.
+                if (processorInputFields.contains(parent.getName()) && processorInputFields.contains(multiField.getName())) {
+                    throw ExceptionsHelper.badRequestException(
+                        "feature_processors refer to both multi-field [{}] and parent [{}]. Please only refer to one or the other",
+                        multiField.getName(),
+                        parent.getName());
+                }
                 nameOrParentToField.put(nameOrParent,
-                    chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection));
+                    chooseMultiFieldOrParent(preferSource, requiredFields, processorInputFields, parent, multiField, fieldSelection));
             }
         }
-        return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints);
+        return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()),
+            extractedFields.getProcessedFields(),
+            cardinalitiesForFieldsWithConstraints);
     }
 
-    private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields, ExtractedField parent,
-                                                    ExtractedField multiField, Set<FieldSelection> fieldSelection) {
+    private ExtractedField chooseMultiFieldOrParent(boolean preferSource,
+                                                    Set<String> requiredFields,
+                                                    Set<String> processorInputFields,
+                                                    ExtractedField parent,
+                                                    ExtractedField multiField,
+                                                    Set<FieldSelection> fieldSelection) {
         // Check requirements first
         if (requiredFields.contains(parent.getName())) {
             addExcludedField(multiField.getName(), "[" + parent.getName() + "] is required instead", fieldSelection);
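
The multi-field handling above encodes two rules: a feature_processor must not consume either variant (parent or multi-field) of a field the analysis requires, and it must not reference both variants of the same field, since only one of them is ultimately extracted per document. A standalone sketch of just the conflict checks; the method and exception types here are illustrative, not the production code:

import java.util.Set;

class MultiFieldConflictSketch {

    static void check(String parent, String multiField, Set<String> requiredFields, Set<String> processorInputFields) {
        // Processing the sibling variant of a required field would modify required data
        if ((requiredFields.contains(parent) && processorInputFields.contains(multiField))
            || (requiredFields.contains(multiField) && processorInputFields.contains(parent))) {
            throw new IllegalArgumentException("feature_processors cannot be applied to required fields");
        }
        // Referencing both variants is ambiguous: only one of them is extracted per document
        if (processorInputFields.contains(parent) && processorInputFields.contains(multiField)) {
            throw new IllegalArgumentException("feature_processors must refer to either the multi-field or its parent, not both");
        }
    }
}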
@@ -333,6 +441,19 @@ public class ExtractedFieldsDetector {
             addExcludedField(parent.getName(), "[" + multiField.getName() + "] is required instead", fieldSelection);
             return multiField;
         }
+        // Choose the one required by our processors
+        if (processorInputFields.contains(parent.getName())) {
+            addExcludedField(multiField.getName(),
+                "[" + parent.getName() + "] is referenced by feature_processors instead",
+                fieldSelection);
+            return parent;
+        }
+        if (processorInputFields.contains(multiField.getName())) {
+            addExcludedField(parent.getName(),
+                "[" + multiField.getName() + "] is referenced by feature_processors instead",
+                fieldSelection);
+            return multiField;
+        }
+
         // If both are multi-fields it means there are several. In this case parent is the previous multi-field
         // we selected. We'll just keep that.
@@ -370,7 +491,9 @@ public class ExtractedFieldsDetector {
         for (ExtractedField field : extractedFields.getAllFields()) {
             adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
         }
-        return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
+        return new ExtractedFields(adjusted,
+            extractedFields.getProcessedFields(),
+            cardinalitiesForFieldsWithConstraints);
     }
 
     private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) {
@@ -387,13 +510,15 @@ public class ExtractedFieldsDetector {
                 adjusted.add(field);
             }
         }
-        return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
+        return new ExtractedFields(adjusted,
+            extractedFields.getProcessedFields(),
+            cardinalitiesForFieldsWithConstraints);
     }
 
     private void addIncludedFields(ExtractedFields extractedFields, Set<FieldSelection> fieldSelection) {
         Set<String> requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName)
             .collect(Collectors.toSet());
-        Set<String> categoricalFields = getCategoricalFields(extractedFields, config.getAnalysis());
+        Set<String> categoricalFields = getCategoricalInputFields(extractedFields, config.getAnalysis());
         for (ExtractedField includedField : extractedFields.getAllFields()) {
             FieldSelection.FeatureType featureType = categoricalFields.contains(includedField.getName()) ?
                 FieldSelection.FeatureType.CATEGORICAL : FieldSelection.FeatureType.NUMERICAL;
@@ -402,7 +527,38 @@ public class ExtractedFieldsDetector {
         }
     }
 
-    static Set<String> getCategoricalFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) {
+    static void checkOutputFeatureUniqueness(List<ProcessedField> processedFields, Set<String> selectedFields) {
+        Set<String> processInputs = processedFields.stream()
+            .map(ProcessedField::getInputFieldNames)
+            .flatMap(List::stream)
+            .collect(Collectors.toSet());
+        // All analysis fields that we include that are NOT processed
+        // This indicates that they are sent as is
+        Set<String> organicFields = Sets.difference(selectedFields, processInputs);
+
+        Set<String> processedFeatures = new HashSet<>();
+        Set<String> duplicatedFields = new HashSet<>();
+        for (ProcessedField processedField : processedFields) {
+            for (String output : processedField.getOutputFieldNames()) {
+                if (processedFeatures.add(output) == false) {
+                    duplicatedFields.add(output);
+                }
+            }
+        }
+        if (duplicatedFields.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException(
+                "feature_processors must define unique output field names; duplicate fields {}",
+                duplicatedFields);
+        }
+        Set<String> duplicateOrganicAndProcessed = Sets.intersection(organicFields, processedFeatures);
+        if (duplicateOrganicAndProcessed.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException(
+                "feature_processors output fields must not include non-processed analysis fields; duplicate fields {}",
+                duplicateOrganicAndProcessed);
+        }
+    }
+
+    static Set<String> getCategoricalInputFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) {
         return extractedFields.getAllFields().stream()
             .filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName())
                 .containsAll(extractedField.getTypes()))
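
checkOutputFeatureUniqueness enforces that processor outputs neither collide with each other nor shadow "organic" fields, i.e. selected fields that are sent to the process as-is. The same logic expressed with plain java.util set operations in place of the Elasticsearch Sets utility; names here are illustrative:

import java.util.HashSet;
import java.util.List;
import java.util.Set;

class OutputUniquenessSketch {

    static void check(List<Set<String>> outputsPerProcessor, Set<String> selectedFields, Set<String> processorInputs) {
        Set<String> organicFields = new HashSet<>(selectedFields);
        organicFields.removeAll(processorInputs); // fields forwarded to the process unmodified

        Set<String> seen = new HashSet<>();
        for (Set<String> outputs : outputsPerProcessor) {
            for (String output : outputs) {
                if (seen.add(output) == false) {
                    throw new IllegalArgumentException("duplicate output field [" + output + "]");
                }
            }
        }
        seen.retainAll(organicFields); // intersection of processed outputs and organic fields
        if (seen.isEmpty() == false) {
            throw new IllegalArgumentException("output fields collide with non-processed fields " + seen);
        }
    }
}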
@@ -410,6 +566,25 @@ public class ExtractedFieldsDetector {
             .collect(Collectors.toSet());
     }
 
+    static Set<String> getCategoricalOutputFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) {
+        Set<String> processInputFields = extractedFields.getProcessedFieldInputs();
+        Set<String> categoricalFields = extractedFields.getAllFields().stream()
+            .filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName())
+                .containsAll(extractedField.getTypes()))
+            .map(ExtractedField::getName)
+            .filter(name -> processInputFields.contains(name) == false)
+            .collect(Collectors.toSet());
+
+        extractedFields.getProcessedFields().forEach(processedField ->
+            processedField.getOutputFieldNames().forEach(outputField -> {
+                if (analysis.getAllowedCategoricalTypes(outputField).containsAll(processedField.getOutputFieldType(outputField))) {
+                    categoricalFields.add(outputField);
+                }
+            })
+        );
+        return Collections.unmodifiableSet(categoricalFields);
+    }
+
     private static boolean isBoolean(Set<String> types) {
         return types.size() == 1 && types.contains(BooleanFieldMapper.CONTENT_TYPE);
     }
@@ -178,7 +178,7 @@ public class AnalyticsProcessManager {
         AnalyticsProcess<AnalyticsResult> process = processContext.process.get();
         AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
         try {
-            writeHeaderRecord(dataExtractor, process);
+            writeHeaderRecord(dataExtractor, process, task);
             writeDataRows(dataExtractor, process, task);
             process.writeEndOfDataMessage();
             process.flushStream();
@@ -268,8 +268,11 @@ public class AnalyticsProcessManager {
         }
     }
 
-    private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process) throws IOException {
+    private void writeHeaderRecord(DataFrameDataExtractor dataExtractor,
+                                   AnalyticsProcess<AnalyticsResult> process,
+                                   DataFrameAnalyticsTask task) throws IOException {
         List<String> fieldNames = dataExtractor.getFieldNames();
+        LOGGER.debug(() -> new ParameterizedMessage("[{}] header row fields {}", task.getParams().getId(), fieldNames));
 
         // We add 2 extra fields, both named dot:
         // - the document hash
@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.dataframe.process;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.LatchedActionListener;
@@ -22,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.security.user.XPackUser;
 import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
@@ -34,6 +36,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
 
 import java.time.Instant;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -191,8 +194,21 @@ public class ChunkedTrainedModelPersister {
         return latch;
     }
 
+    private long customProcessorSize() {
+        List<PreProcessor> preProcessors = new ArrayList<>();
+        if (analytics.getAnalysis() instanceof Classification) {
+            preProcessors = ((Classification) analytics.getAnalysis()).getFeatureProcessors();
+        } else if (analytics.getAnalysis() instanceof Regression) {
+            preProcessors = ((Regression) analytics.getAnalysis()).getFeatureProcessors();
+        }
+        return preProcessors.stream().mapToLong(PreProcessor::ramBytesUsed).sum()
+            + RamUsageEstimator.NUM_BYTES_OBJECT_REF * preProcessors.size();
+    }
+
     private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) {
         Instant createTime = Instant.now();
+        // The native process does not provide estimates for the custom feature_processor objects
+        long customProcessorSize = customProcessorSize();
         String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
         currentModelId.set(modelId);
         List<ExtractedField> fieldNames = extractedFields.getAllFields();
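
Because the native process cannot size the Java-side pre-processors, customProcessorSize() adds their accounted heap usage to the model's estimate: the sum of each processor's ramBytesUsed() plus one object reference per list slot. As a worked example with hypothetical numbers, two processors accounting 400 and 150 bytes would contribute:

long processorBytes = 400L + 150L;
// NUM_BYTES_OBJECT_REF is JVM-dependent: 4 with compressed oops, 8 otherwise
long refOverhead = RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L;
long customProcessorSize = processorBytes + refOverhead; // 558 bytes on a compressed-oops JVM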
@@ -214,7 +230,7 @@ public class ChunkedTrainedModelPersister {
             .setDescription(analytics.getDescription())
             .setMetadata(Collections.singletonMap("analytics_config",
                 XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
-            .setEstimatedHeapMemory(modelSize.ramBytesUsed())
+            .setEstimatedHeapMemory(modelSize.ramBytesUsed() + customProcessorSize)
             .setEstimatedOperations(modelSize.numOperations())
             .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
             .setLicenseLevel(License.OperationMode.PLATINUM.description())
@@ -12,7 +12,7 @@ import org.elasticsearch.index.mapper.BooleanFieldMapper;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.xpack.core.ml.utils.MlStrings;
 
-import java.util.Collection;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -21,27 +21,39 @@ import java.util.Set;
 import java.util.stream.Collectors;
 
 /**
- * The fields the datafeed has to extract
+ * The fields the data[feed|frame] has to extract
  */
 public class ExtractedFields {
 
     private final List<ExtractedField> allFields;
     private final List<ExtractedField> docValueFields;
+    private final List<ProcessedField> processedFields;
     private final String[] sourceFields;
     private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
 
-    public ExtractedFields(List<ExtractedField> allFields, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
-        this.allFields = Collections.unmodifiableList(allFields);
+    public ExtractedFields(List<ExtractedField> allFields,
+                           List<ProcessedField> processedFields,
+                           Map<String, Long> cardinalitiesForFieldsWithConstraints) {
+        this.allFields = new ArrayList<>(allFields);
         this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields);
         this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField)
             .toArray(String[]::new);
         this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints);
+        this.processedFields = processedFields == null ? Collections.emptyList() : processedFields;
+    }
+
+    public List<ProcessedField> getProcessedFields() {
+        return processedFields;
     }
 
     public List<ExtractedField> getAllFields() {
         return allFields;
     }
 
+    public Set<String> getProcessedFieldInputs() {
+        return processedFields.stream().map(ProcessedField::getInputFieldNames).flatMap(List::stream).collect(Collectors.toSet());
+    }
+
     public String[] getSourceFields() {
         return sourceFields;
     }
@@ -58,11 +70,15 @@ public class ExtractedFields {
         return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList());
     }
 
-    public static ExtractedFields build(Collection<String> allFields, Set<String> scriptFields,
+    public static ExtractedFields build(Set<String> allFields,
+                                        Set<String> scriptFields,
                                         FieldCapabilitiesResponse fieldsCapabilities,
-                                        Map<String, Long> cardinalitiesForFieldsWithConstraints) {
+                                        Map<String, Long> cardinalitiesForFieldsWithConstraints,
+                                        List<ProcessedField> processedFields) {
         ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities);
-        return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()),
+        return new ExtractedFields(
+            allFields.stream().map(extractionMethodDetector::detect).collect(Collectors.toList()),
+            processedFields,
             cardinalitiesForFieldsWithConstraints);
     }
@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.extractor;
+
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.function.Function;
+
+public class ProcessedField {
+
+    private final PreProcessor preProcessor;
+
+    public ProcessedField(PreProcessor processor) {
+        this.preProcessor = Objects.requireNonNull(processor);
+    }
+
+    public List<String> getInputFieldNames() {
+        return preProcessor.inputFields();
+    }
+
+    public List<String> getOutputFieldNames() {
+        return preProcessor.outputFields();
+    }
+
+    public Set<String> getOutputFieldType(String outputField) {
+        return Collections.singleton(preProcessor.getOutputFieldType(outputField));
+    }
+
+    public Object[] value(SearchHit hit, Function<String, ExtractedField> fieldExtractor) {
+        Map<String, Object> inputs = new HashMap<>(preProcessor.inputFields().size(), 1.0f);
+        for (String field : preProcessor.inputFields()) {
+            ExtractedField extractedField = fieldExtractor.apply(field);
+            if (extractedField == null) {
+                return new Object[0];
+            }
+            Object[] values = extractedField.value(hit);
+            if (values == null || values.length == 0) {
+                continue;
+            }
+            final Object value = values[0];
+            if (values.length == 1 && (value instanceof String || value instanceof Number)) {
+                inputs.put(field, value);
+            }
+        }
+        preProcessor.process(inputs);
+        return preProcessor.outputFields().stream().map(inputs::get).toArray();
+    }
+
+    public String getProcessorName() {
+        return preProcessor.getName();
+    }
+}
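
ProcessedField.value gathers at most one scalar value per input field, hands the resulting map to the wrapped pre-processor, and reads the outputs back in outputFields() order; a missing extractor aborts the row, while a missing input simply yields nulls for that processor's outputs. As a sketch of the underlying contract, here is the process step with the OneHotEncoding pre-processor used in the tests below; the field names are illustrative:

Map<String, String> hotMap = new HashMap<>();
hotMap.put("11", "field_11");
hotMap.put("12", "field_12");
PreProcessor encoding = new OneHotEncoding("field_1", hotMap, true);

Map<String, Object> fields = new HashMap<>();
fields.put("field_1", "11");
encoding.process(fields);
// fields now also carries the encoded outputs: field_11=1, field_12=0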
@@ -128,4 +128,11 @@ public abstract class MlSingleNodeTestCase extends ESSingleNodeTestCase {
         return responseHolder.get();
     }
 
+    public static void assertNoException(AtomicReference<Exception> error) throws Exception {
+        if (error.get() == null) {
+            return;
+        }
+        throw error.get();
+    }
+
 }
@@ -15,8 +15,10 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.rest.RestStatus;
@@ -27,10 +29,13 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
 import org.elasticsearch.xpack.ml.extractor.DocValueField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
+import org.elasticsearch.xpack.ml.extractor.ProcessedField;
 import org.elasticsearch.xpack.ml.extractor.SourceField;
 import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
 import org.junit.Before;
@@ -45,8 +50,10 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Queue;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
@@ -83,7 +90,9 @@ public class DataFrameDataExtractorTests extends ESTestCase {
         query = QueryBuilders.matchAllQuery();
         extractedFields = new ExtractedFields(Arrays.asList(
             new DocValueField("field_1", Collections.singleton("keyword")),
-            new DocValueField("field_2", Collections.singleton("keyword"))), Collections.emptyMap());
+            new DocValueField("field_2", Collections.singleton("keyword"))),
+            Collections.emptyList(),
+            Collections.emptyMap());
         scrollSize = 1000;
         headers = Collections.emptyMap();
 
@@ -304,7 +313,9 @@ public class DataFrameDataExtractorTests extends ESTestCase {
         // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
         extractedFields = new ExtractedFields(Arrays.asList(
             (ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")),
-            (ExtractedField) new SourceField("field_2", Collections.singleton("text"))), Collections.emptyMap());
+            (ExtractedField) new SourceField("field_2", Collections.singleton("text"))),
+            Collections.emptyList(),
+            Collections.emptyMap());
 
         TestExtractor dataExtractor = createExtractor(false, false);
 
|
||||||
(ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
|
(ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
|
||||||
(ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
|
(ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
|
||||||
(ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
|
(ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
|
||||||
(ExtractedField) new SourceField("field_text", Collections.singleton("text"))), Collections.emptyMap());
|
(ExtractedField) new SourceField("field_text", Collections.singleton("text"))),
|
||||||
|
Collections.emptyList(),
|
||||||
|
Collections.emptyMap());
|
||||||
TestExtractor dataExtractor = createExtractor(true, true);
|
TestExtractor dataExtractor = createExtractor(true, true);
|
||||||
|
|
||||||
assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty());
|
assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty());
|
||||||
|
@@ -465,12 +478,100 @@ public class DataFrameDataExtractorTests extends ESTestCase {
             containsInAnyOrder("field_keyword", "field_text", "field_boolean"));
     }
 
+    public void testGetFieldNames_GivenProcessesFeatures() {
+        // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
+        extractedFields = new ExtractedFields(Arrays.asList(
+            (ExtractedField) new DocValueField("field_boolean", Collections.singleton("boolean")),
+            (ExtractedField) new DocValueField("field_float", Collections.singleton("float")),
+            (ExtractedField) new DocValueField("field_double", Collections.singleton("double")),
+            (ExtractedField) new DocValueField("field_byte", Collections.singleton("byte")),
+            (ExtractedField) new DocValueField("field_short", Collections.singleton("short")),
+            (ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
+            (ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
+            (ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
+            (ExtractedField) new SourceField("field_text", Collections.singleton("text"))),
+            Arrays.asList(
+                new ProcessedField(new CategoricalPreProcessor("field_long", "animal")),
+                buildProcessedField("field_short", "field_1", "field_2")
+            ),
+            Collections.emptyMap());
+        TestExtractor dataExtractor = createExtractor(true, true);
+
+        assertThat(dataExtractor.getCategoricalFields(new Regression("field_double")),
+            containsInAnyOrder("field_keyword", "field_text", "animal"));
+
+        List<String> fieldNames = dataExtractor.getFieldNames();
+        assertThat(fieldNames, containsInAnyOrder(
+            "animal",
+            "field_1",
+            "field_2",
+            "field_boolean",
+            "field_float",
+            "field_double",
+            "field_byte",
+            "field_integer",
+            "field_keyword",
+            "field_text"));
+        assertThat(dataExtractor.getFieldNames(), contains(fieldNames.toArray(new String[0])));
+    }
+
+    public void testExtractionWithProcessedFeatures() throws IOException {
+        extractedFields = new ExtractedFields(Arrays.asList(
+            new DocValueField("field_1", Collections.singleton("keyword")),
+            new DocValueField("field_2", Collections.singleton("keyword"))),
+            Arrays.asList(
+                new ProcessedField(new CategoricalPreProcessor("field_1", "animal")),
+                new ProcessedField(new OneHotEncoding("field_1",
+                    Arrays.asList("11", "12")
+                        .stream()
+                        .collect(Collectors.toMap(Function.identity(), s -> s.equals("11") ? "field_11" : "field_12")),
+                    true))
+            ),
+            Collections.emptyMap());
+
+        TestExtractor dataExtractor = createExtractor(true, true);
+
+        // First and only batch
+        SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3));
+        dataExtractor.setNextResponse(response1);
+
+        // Empty
+        SearchResponse lastAndEmptyResponse = createEmptySearchResponse();
+        dataExtractor.setNextResponse(lastAndEmptyResponse);
+
+        assertThat(dataExtractor.hasNext(), is(true));
+
+        // First batch
+        Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
+        assertThat(rows.isPresent(), is(true));
+        assertThat(rows.get().size(), equalTo(3));
+
+        assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"21", "dog", "1", "0"}));
+        assertThat(rows.get().get(1).getValues(),
+            equalTo(new String[] {"22", "dog", DataFrameDataExtractor.NULL_VALUE, DataFrameDataExtractor.NULL_VALUE}));
+        assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"23", "dog", "0", "0"}));
+
+        assertThat(rows.get().get(0).shouldSkip(), is(false));
+        assertThat(rows.get().get(1).shouldSkip(), is(false));
+        assertThat(rows.get().get(2).shouldSkip(), is(false));
+    }
+
     private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) {
         DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize,
             headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory);
         return new TestExtractor(client, context);
     }
 
+    private static ProcessedField buildProcessedField(String inputField, String... outputFields) {
+        return new ProcessedField(buildPreProcessor(inputField, outputFields));
+    }
+
+    private static PreProcessor buildPreProcessor(String inputField, String... outputFields) {
+        return new OneHotEncoding(inputField,
+            Arrays.stream(outputFields).collect(Collectors.toMap((s) -> randomAlphaOfLength(10), Function.identity())),
+            true);
+    }
+
     private SearchResponse createSearchResponse(List<Number> field1Values, List<Number> field2Values) {
         assertThat(field1Values.size(), equalTo(field2Values.size()));
         SearchResponse searchResponse = mock(SearchResponse.class);
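
For reference, the expected rows in testExtractionWithProcessedFeatures read as follows: field_1 is consumed by the two processors and therefore dropped from the organic output, so each row starts with the field_2 value ("21", "22", "23"), followed by "dog" from the stub categorical processor defined in the next hunk, and then the one-hot columns for field_1 ("1"/"0" when field_1 is "11", "0"/"0" otherwise). For the second document field_1 is missing, so the one-hot outputs fall back to DataFrameDataExtractor.NULL_VALUE while the stub still emits "dog" unconditionally.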
@@ -544,4 +645,70 @@ public class DataFrameDataExtractorTests extends ESTestCase {
             return searchResponse;
         }
     }
+
+    private static class CategoricalPreProcessor implements PreProcessor {
+
+        private final List<String> inputFields;
+        private final List<String> outputFields;
+
+        CategoricalPreProcessor(String inputField, String outputField) {
+            this.inputFields = Arrays.asList(inputField);
+            this.outputFields = Arrays.asList(outputField);
+        }
+
+        @Override
+        public List<String> inputFields() {
+            return inputFields;
+        }
+
+        @Override
+        public List<String> outputFields() {
+            return outputFields;
+        }
+
+        @Override
+        public void process(Map<String, Object> fields) {
+            fields.put(outputFields.get(0), "dog");
+        }
+
+        @Override
+        public Map<String, String> reverseLookup() {
+            return null;
+        }
+
+        @Override
+        public boolean isCustom() {
+            return true;
+        }
+
+        @Override
+        public String getOutputFieldType(String outputField) {
+            return "text";
+        }
+
+        @Override
+        public long ramBytesUsed() {
+            return 0;
+        }
+
+        @Override
+        public String getWriteableName() {
+            return null;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+
+        }
+
+        @Override
+        public String getName() {
+            return null;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return null;
+        }
+    }
 }
@@ -15,10 +15,13 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
 import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
@@ -30,11 +33,14 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.arrayContaining;
+import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
@@ -929,12 +935,23 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
         assertThat(e.getMessage(), equalTo("analyzed_fields must not include or exclude object fields: [object_field]"));
     }
 
+    private static FieldCapabilitiesResponse simpleFieldResponse() {
+        return new MockFieldCapsResponseBuilder()
+            .addAggregatableField("field_11", "float")
+            .addNonAggregatableField("field_21", "float")
+            .addAggregatableField("field_21.child", "float")
+            .addNonAggregatableField("field_31", "float")
+            .addAggregatableField("field_31.child", "float")
+            .addNonAggregatableField("object_field", "object")
+            .build();
+    }
+
     public void testDetect_GivenAnalyzedFieldExcludesObjectField() {
         FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
             .addAggregatableField("float_field", "float")
             .addNonAggregatableField("object_field", "object").build();
 
-        analyzedFields = new FetchSourceContext(true, null, new String[] { "object_field" });
+        analyzedFields = new FetchSourceContext(true, null, new String[]{"object_field"});
 
         ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
             buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap());
@@ -943,6 +960,177 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
         assertThat(e.getMessage(), equalTo("analyzed_fields must not include or exclude object fields: [object_field]"));
     }
 
+    public void testDetect_givenFeatureProcessorsFailures_ResultsField() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("ml.result", "foo"))),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("fields contained in results field [ml] cannot be used in a feature_processor"));
+    }
+
+    public void testDetect_givenFeatureProcessorsFailures_Objects() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("object_field", "foo"))),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("fields for feature_processors must not be objects"));
+    }
+
+    public void testDetect_givenFeatureProcessorsFailures_ReservedFields() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("_id", "foo"))),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("the following fields cannot be used in feature_processors"));
+    }
+
+    public void testDetect_givenFeatureProcessorsFailures_MissingFieldFromIndex() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("bar", "foo"))),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("the fields [bar] were not found in the field capabilities of the source indices"));
+    }
+
+    public void testDetect_givenFeatureProcessorsFailures_UsingRequiredField() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_31", "foo"))),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("required analysis fields [field_31] cannot be used in a feature_processor"));
+    }
+
+    public void testDetect_givenFeatureProcessorsFailures_BadSourceFiltering() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        sourceFiltering = new FetchSourceContext(true, null, new String[]{"field_1*"});
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_11", "foo"))),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("fields [field_11] required by field_processors are not included in source filtering."));
+    }
+
+    public void testDetect_givenFeatureProcessorsFailures_MissingAnalyzedField() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        analyzedFields = new FetchSourceContext(true, null, new String[]{"field_1*"});
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_11", "foo"))),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("fields [field_11] required by field_processors are not included in the analyzed_fields"));
+    }
+
+    public void testDetect_givenFeatureProcessorsFailures_RequiredMultiFields() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_31.child", "foo"))),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("feature_processors cannot be applied to required fields for analysis; "));
+
+        extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31.child", Arrays.asList(buildPreProcessor("field_31", "foo"))),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+
+        ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("feature_processors cannot be applied to required fields for analysis; "));
+    }
+
+    public void testDetect_givenFeatureProcessorsFailures_BothMultiFields() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31",
+                Arrays.asList(
+                    buildPreProcessor("field_21", "foo"),
+                    buildPreProcessor("field_21.child", "bar")
+                )),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("feature_processors refer to both multi-field "));
+    }
+
+    public void testDetect_givenFeatureProcessorsFailures_DuplicateOutputFields() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31",
+                Arrays.asList(
+                    buildPreProcessor("field_11", "foo"),
+                    buildPreProcessor("field_21", "foo")
+                )),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString("feature_processors must define unique output field names; duplicate fields [foo]"));
+    }
+
+    public void testDetect_withFeatureProcessors() {
+        FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
+            .addAggregatableField("field_11", "float")
+            .addAggregatableField("field_21", "float")
+            .addNonAggregatableField("field_31", "float")
+            .addAggregatableField("field_31.child", "float")
+            .addNonAggregatableField("object_field", "object")
+            .build();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_11",
+                Arrays.asList(buildPreProcessor("field_31", "foo", "bar"))),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+
+        ExtractedFields extracted = extractedFieldsDetector.detect().v1();
+
+        assertThat(extracted.getProcessedFieldInputs(), containsInAnyOrder("field_31"));
+        assertThat(extracted.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toSet()),
+            containsInAnyOrder("field_11", "field_21", "field_31"));
+        assertThat(extracted.getSourceFields(), arrayContainingInAnyOrder("field_31"));
+        assertThat(extracted.getDocValueFields().stream().map(ExtractedField::getName).collect(Collectors.toSet()),
+            containsInAnyOrder("field_21", "field_11"));
+        assertThat(extracted.getProcessedFields(), hasSize(1));
+    }
+
     private DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
         return new DataFrameAnalyticsConfig.Builder()
             .setId("foo")
@ -954,13 +1142,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) {
|
private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) {
|
||||||
return new DataFrameAnalyticsConfig.Builder()
|
return buildRegressionConfig(dependentVariable, Collections.emptyList());
|
||||||
.setId("foo")
|
|
||||||
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, sourceFiltering))
|
|
||||||
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
|
|
||||||
.setAnalyzedFields(analyzedFields)
|
|
||||||
.setAnalysis(new Regression(dependentVariable))
|
|
||||||
.build();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private DataFrameAnalyticsConfig buildClassificationConfig(String dependentVariable) {
|
private DataFrameAnalyticsConfig buildClassificationConfig(String dependentVariable) {
|
||||||
@@ -972,6 +1154,29 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
             .build();
     }
 
+    private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, List<PreProcessor> featureprocessors) {
+        return new DataFrameAnalyticsConfig.Builder()
+            .setId("foo")
+            .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, sourceFiltering))
+            .setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
+            .setAnalyzedFields(analyzedFields)
+            .setAnalysis(new Regression(dependentVariable,
+                BoostedTreeParams.builder().build(),
+                null,
+                null,
+                null,
+                null,
+                null,
+                featureprocessors))
+            .build();
+    }
+
+    private static PreProcessor buildPreProcessor(String inputField, String... outputFields) {
+        return new OneHotEncoding(inputField,
+            Arrays.stream(outputFields).collect(Collectors.toMap((s) -> randomAlphaOfLength(10), Function.identity())),
+            true);
+    }
+
     /**
      * We assert each field individually to get useful error messages in case of failure
      */
@@ -987,6 +1192,23 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
         }
     }
 
+    public void testDetect_givenFeatureProcessorsFailures_DuplicateOutputFieldsWithUnProcessedField() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31",
+                Arrays.asList(
+                    buildPreProcessor("field_11", "field_21")
+                )),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString(
+                "feature_processors output fields must not include non-processed analysis fields; duplicate fields [field_21]"));
+    }
+
     private static class MockFieldCapsResponseBuilder {
 
         private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();
@@ -80,6 +80,7 @@ public class InferenceRunnerTests extends ESTestCase {
     public void testInferTestDocs() {
         ExtractedFields extractedFields = new ExtractedFields(
             Collections.singletonList(new SourceField("key", Collections.singleton("integer"))),
+            Collections.emptyList(),
            Collections.emptyMap());
 
         Map<String, Object> doc1 = new HashMap<>();
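The recurring change in this and the following test files is a new middle argument to the ExtractedFields constructor: besides the plain extracted fields and the cardinality map, it now also receives the list of processed (feature_processor-backed) fields. A minimal sketch of the updated call shape, mirroring the hunk above (the role of each argument is inferred from the tests, not stated in the diff):

    // Three-argument form exercised throughout these tests:
    // (allFields, processedFields, cardinality map)
    ExtractedFields extractedFields = new ExtractedFields(
        Collections.singletonList(new SourceField("key", Collections.singleton("integer"))),
        Collections.emptyList(),   // no feature_processors involved here
        Collections.emptyMap());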
@@ -63,7 +63,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase {
     public void testToXContent_GivenOutlierDetection() throws IOException {
         ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
             new DocValueField("field_1", Collections.singleton("double")),
-            new DocValueField("field_2", Collections.singleton("float"))), Collections.emptyMap());
+            new DocValueField("field_2", Collections.singleton("float"))),
+            Collections.emptyList(),
+            Collections.emptyMap());
         DataFrameAnalysis analysis = new OutlierDetection.Builder().build();
 
         AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
@@ -82,7 +84,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase {
         ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
             new DocValueField("field_1", Collections.singleton("double")),
             new DocValueField("field_2", Collections.singleton("float")),
-            new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.emptyMap());
+            new DocValueField("test_dep_var", Collections.singleton("keyword"))),
+            Collections.emptyList(),
+            Collections.emptyMap());
         DataFrameAnalysis analysis = new Regression("test_dep_var");
 
         AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
@@ -103,7 +107,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase {
         ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
             new DocValueField("field_1", Collections.singleton("double")),
             new DocValueField("field_2", Collections.singleton("float")),
-            new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.singletonMap("test_dep_var", 5L));
+            new DocValueField("test_dep_var", Collections.singleton("keyword"))),
+            Collections.emptyList(),
+            Collections.singletonMap("test_dep_var", 5L));
         DataFrameAnalysis analysis = new Classification("test_dep_var");
 
         AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
@@ -126,7 +132,9 @@ public class AnalyticsProcessConfigTests extends ESTestCase {
         ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
             new DocValueField("field_1", Collections.singleton("double")),
             new DocValueField("field_2", Collections.singleton("float")),
-            new DocValueField("test_dep_var", Collections.singleton("integer"))), Collections.singletonMap("test_dep_var", 8L));
+            new DocValueField("test_dep_var", Collections.singleton("integer"))),
+            Collections.emptyList(),
+            Collections.singletonMap("test_dep_var", 8L));
         DataFrameAnalysis analysis = new Classification("test_dep_var");
 
         AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
@@ -105,7 +105,9 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
             OutlierDetectionTests.createRandom()).build();
         dataExtractor = mock(DataFrameDataExtractor.class);
         when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
-        when(dataExtractor.getExtractedFields()).thenReturn(new ExtractedFields(Collections.emptyList(), Collections.emptyMap()));
+        when(dataExtractor.getExtractedFields()).thenReturn(new ExtractedFields(Collections.emptyList(),
+            Collections.emptyList(),
+            Collections.emptyMap()));
         dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
         when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
         when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));
@@ -314,6 +314,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
             trainedModelProvider,
             auditor,
             statsPersister,
-            new ExtractedFields(fieldNames, Collections.emptyMap()));
+            new ExtractedFields(fieldNames, Collections.emptyList(), Collections.emptyMap()));
     }
 }
@@ -144,7 +144,7 @@ public class ChunkedTrainedModelPersisterTests extends ESTestCase {
             analyticsConfig,
             auditor,
             (unused)->{},
-            new ExtractedFields(fieldNames, Collections.emptyMap()));
+            new ExtractedFields(fieldNames, Collections.emptyList(), Collections.emptyMap()));
     }
 
 }
@@ -16,6 +16,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
+import java.util.TreeSet;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
@@ -31,8 +32,10 @@ public class ExtractedFieldsTests extends ESTestCase {
         ExtractedField scriptField2 = new ScriptField("scripted2");
         ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text"));
         ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text"));
-        ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
-            docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), Collections.emptyMap());
+        ExtractedFields extractedFields = new ExtractedFields(
+            Arrays.asList(docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2),
+            Collections.emptyList(),
+            Collections.emptyMap());
 
         assertThat(extractedFields.getAllFields().size(), equalTo(6));
         assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new),
@@ -53,8 +56,11 @@ public class ExtractedFieldsTests extends ESTestCase {
         when(fieldCapabilitiesResponse.getField("value")).thenReturn(valueCaps);
         when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps);
 
-        ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("time", "value", "airline", "airport"),
-            new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse, Collections.emptyMap());
+        ExtractedFields extractedFields = ExtractedFields.build(new TreeSet<>(Arrays.asList("time", "value", "airline", "airport")),
+            new HashSet<>(Collections.singletonList("airport")),
+            fieldCapabilitiesResponse,
+            Collections.emptyMap(),
+            Collections.emptyList());
 
         assertThat(extractedFields.getDocValueFields().size(), equalTo(2));
         assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time"));
@@ -76,8 +82,8 @@ public class ExtractedFieldsTests extends ESTestCase {
         when(fieldCapabilitiesResponse.getField("airport")).thenReturn(text);
         when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword);
 
-        ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("airline.text", "airport.keyword"),
-            Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap());
+        ExtractedFields extractedFields = ExtractedFields.build(new TreeSet<>(Arrays.asList("airline.text", "airport.keyword")),
+            Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap(), Collections.emptyList());
 
         assertThat(extractedFields.getDocValueFields().size(), equalTo(1));
         assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword"));
@@ -112,14 +118,18 @@ public class ExtractedFieldsTests extends ESTestCase {
         assertThat(mapped.getName(), equalTo(aBool.getName()));
         assertThat(mapped.getMethod(), equalTo(aBool.getMethod()));
         assertThat(mapped.supportsFromSource(), is(false));
-        expectThrows(UnsupportedOperationException.class, () -> mapped.newFromSource());
+        expectThrows(UnsupportedOperationException.class, mapped::newFromSource);
     }
 
     public void testBuildGivenFieldWithoutMappings() {
         FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class);
 
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build(
-            Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap()));
+            Collections.singleton("value"),
+            Collections.emptySet(),
+            fieldCapabilitiesResponse,
+            Collections.emptyMap(),
+            Collections.emptyList()));
         assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings"));
     }
 
@@ -0,0 +1,76 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.extractor;
+
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static org.hamcrest.Matchers.arrayContaining;
+import static org.hamcrest.Matchers.emptyArray;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasItems;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class ProcessedFieldTests extends ESTestCase {
+
+    public void testOneHotGetters() {
+        String inputField = "foo";
+        ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
+        assertThat(processedField.getInputFieldNames(), hasItems(inputField));
+        assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column"));
+        assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer")));
+        assertThat(processedField.getOutputFieldType("baz_column"), equalTo(Collections.singleton("integer")));
+        assertThat(processedField.getProcessorName(), equalTo(OneHotEncoding.NAME.getPreferredName()));
+    }
+
+    public void testMissingExtractor() {
+        String inputField = "foo";
+        ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
+        assertThat(processedField.value(makeHit(), (s) -> null), emptyArray());
+    }
+
+    public void testMissingInputValues() {
+        String inputField = "foo";
+        ExtractedField extractedField = makeExtractedField(new Object[0]);
+        ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
+        assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue())));
+    }
+
+    public void testProcessedField() {
+        ProcessedField processedField = new ProcessedField(makePreProcessor("foo", "bar", "baz"));
+        assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" })), arrayContaining(1, 0));
+        assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "baz" })), arrayContaining(0, 1));
+    }
+
+    private static PreProcessor makePreProcessor(String inputField, String... expectedExtractedValues) {
+        return new OneHotEncoding(inputField,
+            Arrays.stream(expectedExtractedValues).collect(Collectors.toMap(Function.identity(), (s) -> s + "_column")),
+            true);
+    }
+
+    private static ExtractedField makeExtractedField(Object[] value) {
+        ExtractedField extractedField = mock(ExtractedField.class);
+        when(extractedField.value(any())).thenReturn(value);
+        return extractedField;
+    }
+
+    private static SearchHit makeHit() {
+        return new SearchHitBuilder(42).addField("a_keyword", "bar").build();
+    }
+}
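To make the arrayContaining assertions above concrete: makePreProcessor builds the hot map {"bar" -> "bar_column", "baz" -> "baz_column"}, so an extracted value of "bar" sets bar_column to 1 and leaves baz_column at 0, and vice versa for "baz". A usage sketch reusing the helpers from this test (output ordering assumed to follow the processor's output fields):

    // Expected one-hot row for an input value of "bar":
    ProcessedField oneHot = new ProcessedField(makePreProcessor("foo", "bar", "baz"));
    Object[] row = oneHot.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" }));
    // row is [1, 0], i.e. [bar_column, baz_column]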