[ML][Inference] adjusting definition object schema and validation (#47447)

* [ML][Inference] adjusting definition object schema and validation
* finalizing schema and fixing inference npe
* addressing PR comments
* fixing for backport
parent ce91ba7c25
commit d33dbf82d4
@@ -22,6 +22,7 @@ import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -38,6 +39,7 @@ public class TrainedModelDefinition implements ToXContentObject {
 
     public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
     public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
+    public static final ParseField INPUT = new ParseField("input");
 
     public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
             true,
@@ -51,6 +53,7 @@ public class TrainedModelDefinition implements ToXContentObject {
             (p, c, n) -> p.namedObject(PreProcessor.class, n, null),
             (trainedModelDefBuilder) -> {/* Does not matter client side*/ },
             PREPROCESSORS);
+        PARSER.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT);
    }
 
    public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
@@ -59,10 +62,12 @@ public class TrainedModelDefinition implements ToXContentObject {
 
     private final TrainedModel trainedModel;
     private final List<PreProcessor> preProcessors;
+    private final Input input;
 
-    TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
+    TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
         this.trainedModel = trainedModel;
         this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
+        this.input = input;
     }
 
     @Override
@@ -78,6 +83,9 @@ public class TrainedModelDefinition implements ToXContentObject {
             true,
             PREPROCESSORS.getPreferredName(),
             preProcessors);
+        if (input != null) {
+            builder.field(INPUT.getPreferredName(), input);
+        }
         builder.endObject();
         return builder;
     }
@@ -90,6 +98,10 @@ public class TrainedModelDefinition implements ToXContentObject {
         return preProcessors;
     }
 
+    public Input getInput() {
+        return input;
+    }
+
     @Override
     public String toString() {
         return Strings.toString(this);
@@ -101,18 +113,20 @@ public class TrainedModelDefinition implements ToXContentObject {
         if (o == null || getClass() != o.getClass()) return false;
         TrainedModelDefinition that = (TrainedModelDefinition) o;
         return Objects.equals(trainedModel, that.trainedModel) &&
-            Objects.equals(preProcessors, that.preProcessors) ;
+            Objects.equals(preProcessors, that.preProcessors) &&
+            Objects.equals(input, that.input);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(trainedModel, preProcessors);
+        return Objects.hash(trainedModel, preProcessors, input);
     }
 
     public static class Builder {
 
         private List<PreProcessor> preProcessors;
         private TrainedModel trainedModel;
+        private Input input;
 
         public Builder setPreProcessors(List<PreProcessor> preProcessors) {
             this.preProcessors = preProcessors;
@@ -124,14 +138,71 @@ public class TrainedModelDefinition implements ToXContentObject {
             return this;
         }
 
+        public Builder setInput(Input input) {
+            this.input = input;
+            return this;
+        }
+
         private Builder setTrainedModel(List<TrainedModel> trainedModel) {
             assert trainedModel.size() == 1;
             return setTrainedModel(trainedModel.get(0));
         }
 
         public TrainedModelDefinition build() {
-            return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
+            return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
         }
     }
 
+    public static class Input implements ToXContentObject {
+
+        public static final String NAME = "trained_mode_definition_input";
+        public static final ParseField FIELD_NAMES = new ParseField("field_names");
+
+        @SuppressWarnings("unchecked")
+        public static final ConstructingObjectParser<Input, Void> PARSER = new ConstructingObjectParser<>(NAME,
+            true,
+            a -> new Input((List<String>)a[0]));
+        static {
+            PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
+        }
+
+        public static Input fromXContent(XContentParser parser) throws IOException {
+            return PARSER.parse(parser, null);
+        }
+
+        private final List<String> fieldNames;
+
+        public Input(List<String> fieldNames) {
+            this.fieldNames = fieldNames;
+        }
+
+        public List<String> getFieldNames() {
+            return fieldNames;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            if (fieldNames != null) {
+                builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
+            }
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
+            return Objects.equals(fieldNames, that.fieldNames);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(fieldNames);
+        }
+
+    }
+
 }
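The client-side change above adds an optional `input` object carrying the model's expected `field_names`. A minimal sketch of populating it through the builder, assuming the high-level REST client classes from this diff are on the classpath and that `Builder` keeps an implicit public constructor, which the diff does not show (the column names are hypothetical):

    import java.util.Arrays;

    import org.elasticsearch.client.ml.inference.TrainedModelDefinition;

    public class ClientInputSketch {
        public static void main(String[] args) {
            // Input is optional client side: toXContent only emits it when non-null.
            TrainedModelDefinition.Input input =
                new TrainedModelDefinition.Input(Arrays.asList("col1", "col2", "col3"));
            TrainedModelDefinition.Builder builder = new TrainedModelDefinition.Builder()
                .setInput(input);
        }
    }

Note that the client parsers stay lenient throughout (`new ObjectParser<>(NAME, true, ...)` and `new ConstructingObjectParser<>(NAME, true, ...)`), so unknown fields coming back from a newer server are ignored rather than fatal.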
@@ -39,7 +39,7 @@ public class TargetMeanEncoding implements PreProcessor {
     public static final String NAME = "target_mean_encoding";
     public static final ParseField FIELD = new ParseField("field");
     public static final ParseField FEATURE_NAME = new ParseField("feature_name");
-    public static final ParseField TARGET_MEANS = new ParseField("target_means");
+    public static final ParseField TARGET_MAP = new ParseField("target_map");
     public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
 
     @SuppressWarnings("unchecked")
@@ -52,7 +52,7 @@ public class TargetMeanEncoding implements PreProcessor {
         PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
         PARSER.declareObject(ConstructingObjectParser.constructorArg(),
             (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
-            TARGET_MEANS);
+            TARGET_MAP);
         PARSER.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
     }
 
@@ -110,7 +110,7 @@ public class TargetMeanEncoding implements PreProcessor {
         builder.startObject();
         builder.field(FIELD.getPreferredName(), field);
         builder.field(FEATURE_NAME.getPreferredName(), featureName);
-        builder.field(TARGET_MEANS.getPreferredName(), meanMap);
+        builder.field(TARGET_MAP.getPreferredName(), meanMap);
         builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
         builder.endObject();
         return builder;
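Both here and in the x-pack core version further down, `TargetMeanEncoding` now serializes its mean map under `target_map` rather than `target_means`, in line with the `hot_map`/`frequency_map` naming the other preprocessors use in the test fixtures below. A hedged sketch of parsing the renamed field with the client `PARSER` shown above (values are hypothetical; `NamedXContentRegistry.EMPTY` suffices because no named objects are involved, and `PARSER` is assumed public as on the other client classes in this commit):

    import java.io.IOException;

    import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
    import org.elasticsearch.common.xcontent.DeprecationHandler;
    import org.elasticsearch.common.xcontent.NamedXContentRegistry;
    import org.elasticsearch.common.xcontent.XContentFactory;
    import org.elasticsearch.common.xcontent.XContentParser;
    import org.elasticsearch.common.xcontent.XContentType;

    public class TargetMapParseSketch {
        public static void main(String[] args) throws IOException {
            String json = "{\"field\":\"col2\",\"feature_name\":\"col2_encoded\","
                + "\"target_map\":{\"S\":5.0,\"M\":10.0},\"default_value\":5.0}";
            XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json);
            // ConstructingObjectParser implements BiFunction, so apply() runs the parse.
            TargetMeanEncoding encoding = TargetMeanEncoding.PARSER.apply(parser, null);
        }
    }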
@@ -64,7 +64,10 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase<TrainedModelDefinition> {
             TargetMeanEncodingTests.createRandom()))
             .limit(numberOfProcessors)
             .collect(Collectors.toList()))
-            .setTrainedModel(randomFrom(TreeTests.createRandom()));
+            .setTrainedModel(randomFrom(TreeTests.createRandom()))
+            .setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
+                .limit(randomLongBetween(1, 10))
+                .collect(Collectors.toList())));
     }
 
     @Override
@@ -10,6 +10,7 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -30,10 +31,11 @@ import java.util.Objects;
 
 public class TrainedModelDefinition implements ToXContentObject, Writeable {
 
-    public static final String NAME = "trained_model_doc";
+    public static final String NAME = "trained_mode_definition";
 
     public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
     public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
+    public static final ParseField INPUT = new ParseField("input");
 
     // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
     public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
@@ -55,6 +57,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
             p.namedObject(StrictlyParsedPreProcessor.class, n, null),
             (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
             PREPROCESSORS);
+        parser.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p, ignoreUnknownFields), INPUT);
         return parser;
     }
 
@@ -64,21 +67,25 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
 
     private final TrainedModel trainedModel;
     private final List<PreProcessor> preProcessors;
+    private final Input input;
 
-    TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
-        this.trainedModel = trainedModel;
+    TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
+        this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
         this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
+        this.input = ExceptionsHelper.requireNonNull(input, INPUT);
     }
 
     public TrainedModelDefinition(StreamInput in) throws IOException {
         this.trainedModel = in.readNamedWriteable(TrainedModel.class);
         this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
+        this.input = new Input(in);
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeNamedWriteable(trainedModel);
         out.writeNamedWriteableList(preProcessors);
+        input.writeTo(out);
     }
 
     @Override
@@ -94,6 +101,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
             true,
             PREPROCESSORS.getPreferredName(),
             preProcessors);
+        builder.field(INPUT.getPreferredName(), input);
         builder.endObject();
         return builder;
     }
@@ -106,6 +114,10 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
         return preProcessors;
     }
 
+    public Input getInput() {
+        return input;
+    }
+
     @Override
     public String toString() {
         return Strings.toString(this);
@@ -117,12 +129,13 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
         if (o == null || getClass() != o.getClass()) return false;
         TrainedModelDefinition that = (TrainedModelDefinition) o;
         return Objects.equals(trainedModel, that.trainedModel) &&
-            Objects.equals(preProcessors, that.preProcessors) ;
+            Objects.equals(input, that.input) &&
+            Objects.equals(preProcessors, that.preProcessors);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(trainedModel, preProcessors);
+        return Objects.hash(trainedModel, input, preProcessors);
     }
 
     public static class Builder {
@@ -130,6 +143,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
         private List<PreProcessor> preProcessors;
         private TrainedModel trainedModel;
         private boolean processorsInOrder;
+        private Input input;
 
         private static Builder builderForParser() {
             return new Builder(false);
@@ -153,6 +167,11 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
             return this;
         }
 
+        public Builder setInput(Input input) {
+            this.input = input;
+            return this;
+        }
+
         private Builder setTrainedModel(List<TrainedModel> trainedModel) {
             if (trainedModel.size() != 1) {
                 throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
@@ -169,8 +188,71 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
             if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
                 throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
             }
-            return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
+            return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
         }
     }
 
+    public static class Input implements ToXContentObject, Writeable {
+
+        public static final String NAME = "trained_mode_definition_input";
+        public static final ParseField FIELD_NAMES = new ParseField("field_names");
+
+        public static final ConstructingObjectParser<Input, Void> LENIENT_PARSER = createParser(true);
+        public static final ConstructingObjectParser<Input, Void> STRICT_PARSER = createParser(false);
+
+        @SuppressWarnings("unchecked")
+        private static ConstructingObjectParser<Input, Void> createParser(boolean ignoreUnknownFields) {
+            ConstructingObjectParser<Input, Void> parser = new ConstructingObjectParser<>(NAME,
+                ignoreUnknownFields,
+                a -> new Input((List<String>)a[0]));
+            parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
+            return parser;
+        }
+
+        public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException {
+            return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
+        }
+
+        private final List<String> fieldNames;
+
+        public Input(List<String> fieldNames) {
+            this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES));
+        }
+
+        public Input(StreamInput in) throws IOException {
+            this.fieldNames = Collections.unmodifiableList(in.readStringList());
+        }
+
+        public List<String> getFieldNames() {
+            return fieldNames;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeStringCollection(fieldNames);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
+            return Objects.equals(fieldNames, that.fieldNames);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(fieldNames);
+        }
+
+    }
+
 }
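Two validation changes land in the core class above: the constructor now requires both `trained_model` and `input` via `ExceptionsHelper.requireNonNull`, and the new `Input` gets the same lenient/strict parser pair as the enclosing definition. A sketch of that strict/lenient split (the stray "foo" key is made up; the helper just wires up a JSON parser the same way the tests in this commit do):

    import java.io.IOException;
    import java.util.List;

    import org.elasticsearch.common.xcontent.DeprecationHandler;
    import org.elasticsearch.common.xcontent.NamedXContentRegistry;
    import org.elasticsearch.common.xcontent.XContentFactory;
    import org.elasticsearch.common.xcontent.XContentParser;
    import org.elasticsearch.common.xcontent.XContentType;
    import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;

    public class InputStrictnessSketch {
        private static XContentParser parser(String json) throws IOException {
            return XContentFactory.xContent(XContentType.JSON)
                .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json);
        }

        public static void main(String[] args) throws IOException {
            String json = "{\"field_names\":[\"col1\",\"col2\"],\"foo\":\"bar\"}";
            // Lenient (metadata) path: the unknown "foo" key is ignored.
            TrainedModelDefinition.Input lenient = TrainedModelDefinition.Input.fromXContent(parser(json), true);
            List<String> names = lenient.getFieldNames(); // ["col1", "col2"], unmodifiable
            // Strict (config) path: the same document fails on the unknown key.
            // TrainedModelDefinition.Input.fromXContent(parser(json), false); // throws
        }
    }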
@@ -28,7 +28,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
     public static final ParseField NAME = new ParseField("target_mean_encoding");
     public static final ParseField FIELD = new ParseField("field");
     public static final ParseField FEATURE_NAME = new ParseField("feature_name");
-    public static final ParseField TARGET_MEANS = new ParseField("target_means");
+    public static final ParseField TARGET_MAP = new ParseField("target_map");
     public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
 
     public static final ConstructingObjectParser<TargetMeanEncoding, Void> STRICT_PARSER = createParser(false);
@@ -44,7 +44,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
         parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
         parser.declareObject(ConstructingObjectParser.constructorArg(),
             (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
-            TARGET_MEANS);
+            TARGET_MAP);
         parser.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
         return parser;
     }
@@ -65,7 +65,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
     public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
         this.field = ExceptionsHelper.requireNonNull(field, FIELD);
         this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME);
-        this.meanMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(meanMap, TARGET_MEANS));
+        this.meanMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(meanMap, TARGET_MAP));
         this.defaultValue = ExceptionsHelper.requireNonNull(defaultValue, DEFAULT_VALUE);
     }
 
@@ -136,7 +136,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
         builder.startObject();
         builder.field(FIELD.getPreferredName(), field);
         builder.field(FEATURE_NAME.getPreferredName(), featureName);
-        builder.field(TARGET_MEANS.getPreferredName(), meanMap);
+        builder.field(TARGET_MAP.getPreferredName(), meanMap);
         builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
         builder.endObject();
         return builder;
@@ -107,14 +107,13 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
 
     @Override
     public double infer(Map<String, Object> fields) {
-        List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
-        return infer(features);
+        List<Double> processedInferences = inferAndProcess(fields);
+        return outputAggregator.aggregate(processedInferences);
     }
 
     @Override
     public double infer(List<Double> fields) {
-        List<Double> processedInferences = inferAndProcess(fields);
-        return outputAggregator.aggregate(processedInferences);
+        throw new UnsupportedOperationException("Ensemble requires map containing field names and values");
     }
 
     @Override
@@ -128,17 +127,12 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
             throw new UnsupportedOperationException(
                 "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
         }
-        List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
-        return classificationProbability(features);
+        return inferAndProcess(fields);
     }
 
     @Override
     public List<Double> classificationProbability(List<Double> fields) {
-        if ((targetType == TargetType.CLASSIFICATION) == false) {
-            throw new UnsupportedOperationException(
-                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
-        }
-        return inferAndProcess(fields);
+        throw new UnsupportedOperationException("Ensemble requires map containing field names and values");
     }
 
     @Override
@@ -146,7 +140,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
         return classificationLabels;
     }
 
-    private List<Double> inferAndProcess(List<Double> fields) {
+    private List<Double> inferAndProcess(Map<String, Object> fields) {
         List<Double> modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList());
         return outputAggregator.processValues(modelInferences);
     }
@@ -210,15 +204,6 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
 
     @Override
     public void validate() {
-        if (this.featureNames != null) {
-            if (this.models.stream()
-                .anyMatch(trainedModel -> trainedModel.getFeatureNames().equals(this.featureNames) == false)) {
-                throw ExceptionsHelper.badRequestException(
-                    "[{}] must be the same and in the same order for each of the {}",
-                    FEATURE_NAMES.getPreferredName(),
-                    TRAINED_MODELS.getPreferredName());
-            }
-        }
         if (outputAggregator.expectedValueSize() != null &&
             outputAggregator.expectedValueSize() != models.size()) {
             throw ExceptionsHelper.badRequestException(
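`Ensemble` now runs entirely off the named-field map: the `List<Double>` entry points throw `UnsupportedOperationException`, and `inferAndProcess` hands the same map to every submodel before the aggregator combines the outputs. A simplified standalone illustration of that flow, with plain `Function`s standing in for the ES `TrainedModel` and `OutputAggregator` types (not the real classes):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.function.Function;
    import java.util.stream.Collectors;

    public class EnsembleFlowSketch {
        static double infer(List<Function<Map<String, Object>, Double>> models,
                            Function<List<Double>, Double> aggregator,
                            Map<String, Object> fields) {
            // Each submodel resolves its own feature names from the shared map...
            List<Double> modelInferences =
                models.stream().map(m -> m.apply(fields)).collect(Collectors.toList());
            // ...and the aggregator combines the per-model outputs.
            return aggregator.apply(modelInferences);
        }

        public static void main(String[] args) {
            Map<String, Object> fields = new HashMap<>();
            fields.put("foo", 0.3);
            double result = infer(
                Arrays.asList(f -> 1.0, f -> 2.0),                            // stub submodels
                outs -> outs.stream().mapToDouble(Double::doubleValue).sum(), // stub aggregator
                fields);
            // result == 3.0
        }
    }

Consistent with that, `validate()` no longer forces every submodel to declare exactly the ensemble's `feature_names`, since each submodel now looks up its own features from the map; the test that enforced the old check is deleted below.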
@@ -106,7 +106,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
 
     @Override
     public double infer(Map<String, Object> fields) {
-        List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
+        List<Double> features = featureNames.stream().map(f ->
+            fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null
+        ).collect(Collectors.toList());
         return infer(features);
     }
 
@@ -146,7 +148,11 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
             throw new UnsupportedOperationException(
                 "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
         }
-        return classificationProbability(featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()));
+        List<Double> features = featureNames.stream().map(f ->
+            fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null)
+            .collect(Collectors.toList());
+
+        return classificationProbability(features);
     }
 
     @Override
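`Tree` now builds its feature vector defensively: a field that is absent or not a `Number` becomes a null feature instead of a `ClassCastException`, and (per the tests below) a null at a split node is routed by that node's `default_left` flag. The extraction step in isolation, runnable as-is:

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;

    public class MissingFeatureSketch {
        public static void main(String[] args) {
            List<String> featureNames = Arrays.asList("foo", "bar");
            Map<String, Object> fields = new HashMap<>();
            fields.put("foo", 0.3); // "bar" is missing entirely
            List<Double> features = featureNames.stream().map(f ->
                fields.get(f) instanceof Number ? ((Number) fields.get(f)).doubleValue() : null
            ).collect(Collectors.toList());
            // features == [0.3, null]; the null is later routed down the default_left branch
        }
    }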
@@ -8,13 +8,18 @@ package org.elasticsearch.xpack.core.ml.inference;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
 import org.junit.Before;
 
@@ -26,6 +31,8 @@ import java.util.function.Predicate;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
+import static org.hamcrest.Matchers.equalTo;
+
 
 public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {
 
@@ -61,8 +68,229 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {
             TargetMeanEncodingTests.createRandom()))
             .limit(numberOfProcessors)
             .collect(Collectors.toList()))
+            .setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
+                .limit(randomLongBetween(1, 10))
+                .collect(Collectors.toList())))
             .setTrainedModel(randomFrom(TreeTests.createRandom()));
     }
 
+    private static final String ENSEMBLE_MODEL = "" +
+        "{\n" +
+        "  \"input\": {\n" +
+        "    \"field_names\": [\n" +
+        "      \"col1\",\n" +
+        "      \"col2\",\n" +
+        "      \"col3\",\n" +
+        "      \"col4\"\n" +
+        "    ]\n" +
+        "  },\n" +
+        "  \"preprocessors\": [\n" +
+        "    {\n" +
+        "      \"one_hot_encoding\": {\n" +
+        "        \"field\": \"col1\",\n" +
+        "        \"hot_map\": {\n" +
+        "          \"male\": \"col1_male\",\n" +
+        "          \"female\": \"col1_female\"\n" +
+        "        }\n" +
+        "      }\n" +
+        "    },\n" +
+        "    {\n" +
+        "      \"target_mean_encoding\": {\n" +
+        "        \"field\": \"col2\",\n" +
+        "        \"feature_name\": \"col2_encoded\",\n" +
+        "        \"target_map\": {\n" +
+        "          \"S\": 5.0,\n" +
+        "          \"M\": 10.0,\n" +
+        "          \"L\": 20\n" +
+        "        },\n" +
+        "        \"default_value\": 5.0\n" +
+        "      }\n" +
+        "    },\n" +
+        "    {\n" +
+        "      \"frequency_encoding\": {\n" +
+        "        \"field\": \"col3\",\n" +
+        "        \"feature_name\": \"col3_encoded\",\n" +
+        "        \"frequency_map\": {\n" +
+        "          \"none\": 0.75,\n" +
+        "          \"true\": 0.10,\n" +
+        "          \"false\": 0.15\n" +
+        "        }\n" +
+        "      }\n" +
+        "    }\n" +
+        "  ],\n" +
+        "  \"trained_model\": {\n" +
+        "    \"ensemble\": {\n" +
+        "      \"feature_names\": [\n" +
+        "        \"col1_male\",\n" +
+        "        \"col1_female\",\n" +
+        "        \"col2_encoded\",\n" +
+        "        \"col3_encoded\",\n" +
+        "        \"col4\"\n" +
+        "      ],\n" +
+        "      \"aggregate_output\": {\n" +
+        "        \"weighted_sum\": {\n" +
+        "          \"weights\": [\n" +
+        "            0.5,\n" +
+        "            0.5\n" +
+        "          ]\n" +
+        "        }\n" +
+        "      },\n" +
+        "      \"target_type\": \"regression\",\n" +
+        "      \"trained_models\": [\n" +
+        "        {\n" +
+        "          \"tree\": {\n" +
+        "            \"feature_names\": [\n" +
+        "              \"col1_male\",\n" +
+        "              \"col1_female\",\n" +
+        "              \"col4\"\n" +
+        "            ],\n" +
+        "            \"tree_structure\": [\n" +
+        "              {\n" +
+        "                \"node_index\": 0,\n" +
+        "                \"split_feature\": 0,\n" +
+        "                \"split_gain\": 12.0,\n" +
+        "                \"threshold\": 10.0,\n" +
+        "                \"decision_type\": \"lte\",\n" +
+        "                \"default_left\": true,\n" +
+        "                \"left_child\": 1,\n" +
+        "                \"right_child\": 2\n" +
+        "              },\n" +
+        "              {\n" +
+        "                \"node_index\": 1,\n" +
+        "                \"leaf_value\": 1\n" +
+        "              },\n" +
+        "              {\n" +
+        "                \"node_index\": 2,\n" +
+        "                \"leaf_value\": 2\n" +
+        "              }\n" +
+        "            ],\n" +
+        "            \"target_type\": \"regression\"\n" +
+        "          }\n" +
+        "        },\n" +
+        "        {\n" +
+        "          \"tree\": {\n" +
+        "            \"feature_names\": [\n" +
+        "              \"col2_encoded\",\n" +
+        "              \"col3_encoded\",\n" +
+        "              \"col4\"\n" +
+        "            ],\n" +
+        "            \"tree_structure\": [\n" +
+        "              {\n" +
+        "                \"node_index\": 0,\n" +
+        "                \"split_feature\": 0,\n" +
+        "                \"split_gain\": 12.0,\n" +
+        "                \"threshold\": 10.0,\n" +
+        "                \"decision_type\": \"lte\",\n" +
+        "                \"default_left\": true,\n" +
+        "                \"left_child\": 1,\n" +
+        "                \"right_child\": 2\n" +
+        "              },\n" +
+        "              {\n" +
+        "                \"node_index\": 1,\n" +
+        "                \"leaf_value\": 1\n" +
+        "              },\n" +
+        "              {\n" +
+        "                \"node_index\": 2,\n" +
+        "                \"leaf_value\": 2\n" +
+        "              }\n" +
+        "            ],\n" +
+        "            \"target_type\": \"regression\"\n" +
+        "          }\n" +
+        "        }\n" +
+        "      ]\n" +
+        "    }\n" +
+        "  }\n" +
+        "}";
+    private static final String TREE_MODEL = "" +
+        "{\n" +
+        "  \"input\": {\n" +
+        "    \"field_names\": [\n" +
+        "      \"col1\",\n" +
+        "      \"col2\",\n" +
+        "      \"col3\",\n" +
+        "      \"col4\"\n" +
+        "    ]\n" +
+        "  },\n" +
+        "  \"preprocessors\": [\n" +
+        "    {\n" +
+        "      \"one_hot_encoding\": {\n" +
+        "        \"field\": \"col1\",\n" +
+        "        \"hot_map\": {\n" +
+        "          \"male\": \"col1_male\",\n" +
+        "          \"female\": \"col1_female\"\n" +
+        "        }\n" +
+        "      }\n" +
+        "    },\n" +
+        "    {\n" +
+        "      \"target_mean_encoding\": {\n" +
+        "        \"field\": \"col2\",\n" +
+        "        \"feature_name\": \"col2_encoded\",\n" +
+        "        \"target_map\": {\n" +
+        "          \"S\": 5.0,\n" +
+        "          \"M\": 10.0,\n" +
+        "          \"L\": 20\n" +
+        "        },\n" +
+        "        \"default_value\": 5.0\n" +
+        "      }\n" +
+        "    },\n" +
+        "    {\n" +
+        "      \"frequency_encoding\": {\n" +
+        "        \"field\": \"col3\",\n" +
+        "        \"feature_name\": \"col3_encoded\",\n" +
+        "        \"frequency_map\": {\n" +
+        "          \"none\": 0.75,\n" +
+        "          \"true\": 0.10,\n" +
+        "          \"false\": 0.15\n" +
+        "        }\n" +
+        "      }\n" +
+        "    }\n" +
+        "  ],\n" +
+        "  \"trained_model\": {\n" +
+        "    \"tree\": {\n" +
+        "      \"feature_names\": [\n" +
+        "        \"col1_male\",\n" +
+        "        \"col1_female\",\n" +
+        "        \"col4\"\n" +
+        "      ],\n" +
+        "      \"tree_structure\": [\n" +
+        "        {\n" +
+        "          \"node_index\": 0,\n" +
+        "          \"split_feature\": 0,\n" +
+        "          \"split_gain\": 12.0,\n" +
+        "          \"threshold\": 10.0,\n" +
+        "          \"decision_type\": \"lte\",\n" +
+        "          \"default_left\": true,\n" +
+        "          \"left_child\": 1,\n" +
+        "          \"right_child\": 2\n" +
+        "        },\n" +
+        "        {\n" +
+        "          \"node_index\": 1,\n" +
+        "          \"leaf_value\": 1\n" +
+        "        },\n" +
+        "        {\n" +
+        "          \"node_index\": 2,\n" +
+        "          \"leaf_value\": 2\n" +
+        "        }\n" +
+        "      ],\n" +
+        "      \"target_type\": \"regression\"\n" +
+        "    }\n" +
+        "  }\n" +
+        "}";
+
+    public void testEnsembleSchemaDeserialization() throws IOException {
+        XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+            .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, ENSEMBLE_MODEL);
+        TrainedModelDefinition definition = TrainedModelDefinition.fromXContent(parser, false).build();
+        assertThat(definition.getTrainedModel().getClass(), equalTo(Ensemble.class));
+    }
+
+    public void testTreeSchemaDeserialization() throws IOException {
+        XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+            .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, TREE_MODEL);
+        TrainedModelDefinition definition = TrainedModelDefinition.fromXContent(parser, false).build();
+        assertThat(definition.getTrainedModel().getClass(), equalTo(Tree.class));
+    }
+
     @Override
     protected TrainedModelDefinition createTestInstance() {
         return createRandomBuilder().build();
@@ -26,6 +26,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Predicate;
@@ -108,25 +109,6 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         return new NamedWriteableRegistry(entries);
     }
 
-    public void testEnsembleWithModelsThatHaveDifferentFeatureNames() {
-        List<String> featureNames = Arrays.asList("foo", "bar", "baz", "farequote");
-        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
-            Ensemble.builder().setFeatureNames(featureNames)
-                .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("bar", "foo", "baz", "farequote"), 6)))
-                .build()
-                .validate();
-        });
-        assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));
-
-        ex = expectThrows(ElasticsearchException.class, () -> {
-            Ensemble.builder().setFeatureNames(featureNames)
-                .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("completely_different"), 6)))
-                .build()
-                .validate();
-        });
-        assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));
-    }
-
     public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() {
         List<String> featureNames = Arrays.asList("foo", "bar");
         int numberOfModels = 5;
@@ -279,6 +261,17 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
         }
+
+        // This should handle missing values and take the default_left path
+        featureMap = new HashMap<String, Object>(2) {{
+            put("foo", 0.3);
+            put("bar", null);
+        }};
+        expected = Arrays.asList(0.6899744811, 0.3100255188);
+        probabilities = ensemble.classificationProbability(featureMap);
+        for(int i = 0; i < expected.size(); i++) {
+            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
+        }
     }
 
     public void testClassificationInference() {
@@ -336,6 +329,12 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         featureVector = Arrays.asList(0.0, 1.0);
         featureMap = zipObjMap(featureNames, featureVector);
         assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
+
+        featureMap = new HashMap<String, Object>(2) {{
+            put("foo", 0.3);
+            put("bar", null);
+        }};
+        assertEquals(0.0, ensemble.infer(featureMap), 0.00001);
     }
 
     public void testRegressionInference() {
@@ -394,6 +393,12 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         featureVector = Arrays.asList(2.0, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
+
+        featureMap = new HashMap<String, Object>(2) {{
+            put("foo", 0.3);
+            put("bar", null);
+        }};
+        assertEquals(1.8, ensemble.infer(featureMap), 0.00001);
     }
 
     private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
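The new missing-value cases above build their maps with the anonymous-subclass (double-brace) idiom so they can hold an explicit null; any `HashMap` behaves the same, since `HashMap` permits null values (unlike, say, `Map.of`):

    import java.util.HashMap;
    import java.util.Map;

    public class NullFeatureMapSketch {
        public static void main(String[] args) {
            Map<String, Object> featureMap = new HashMap<>();
            featureMap.put("foo", 0.3);
            featureMap.put("bar", null); // explicit missing value, same effect as omitting the key
        }
    }

That an omitted key and an explicit null behave alike follows from the `instanceof Number` check in `Tree.infer`: both produce a null feature.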
@@ -17,12 +17,14 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 
 
@@ -118,19 +120,26 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
         // This feature vector should hit the right child of the root node
         List<Double> featureVector = Arrays.asList(0.6, 0.0);
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
-        assertEquals(0.3, tree.infer(featureMap), 0.00001);
+        assertThat(0.3, closeTo(tree.infer(featureMap), 0.00001));
 
         // This should hit the left child of the left child of the root node
         // i.e. it takes the path left, left
         featureVector = Arrays.asList(0.3, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
-        assertEquals(0.1, tree.infer(featureMap), 0.00001);
+        assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001));
 
         // This should hit the right child of the left child of the root node
         // i.e. it takes the path left, right
         featureVector = Arrays.asList(0.3, 0.9);
         featureMap = zipObjMap(featureNames, featureVector);
-        assertEquals(0.2, tree.infer(featureMap), 0.00001);
+        assertThat(0.2, closeTo(tree.infer(featureMap), 0.00001));
+
+        // This should handle missing values and take the default_left path
+        featureMap = new HashMap<String, Object>(2) {{
+            put("foo", 0.3);
+            put("bar", null);
+        }};
+        assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001));
     }
 
     public void testTreeClassificationProbability() {
@@ -162,6 +171,13 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
         featureVector = Arrays.asList(0.3, 0.9);
         featureMap = zipObjMap(featureNames, featureVector);
         assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap));
+
+        // This should handle missing values and take the default_left path
+        featureMap = new HashMap<String, Object>(2) {{
+            put("foo", 0.3);
+            put("bar", null);
+        }};
+        assertEquals(1.0, tree.infer(featureMap), 0.00001);
     }
 
     public void testTreeWithNullRoot() {