* [ML][Inference] adding ensemble model objects
* addressing PR comments
* Update TreeTests.java
* addressing PR comments
* fixing test
This commit is contained in:
parent b9541eb3af
commit 2228a7dd8d
@@ -19,6 +19,10 @@
package org.elasticsearch.client.ml.inference;

import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;

@@ -47,6 +51,15 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {

        // Model
        namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
        namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Ensemble.NAME), Ensemble::fromXContent));

        // Aggregating output
        namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
            new ParseField(WeightedMode.NAME),
            WeightedMode::fromXContent));
        namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
            new ParseField(WeightedSum.NAME),
            WeightedSum::fromXContent));

        return namedXContent;
    }
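For context, these entries are consumed by building a NamedXContentRegistry. The sketch below mirrors how EnsembleTests, later in this commit, constructs its registry; the wrapper class name is hypothetical.

    import java.util.ArrayList;
    import java.util.List;
    import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
    import org.elasticsearch.common.xcontent.NamedXContentRegistry;

    public class RegistrySketch {
        public static NamedXContentRegistry buildRegistry() {
            // Collect the tree/ensemble/aggregator parsers registered above.
            List<NamedXContentRegistry.Entry> entries =
                new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
            return new NamedXContentRegistry(entries);
        }
    }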
@@ -0,0 +1,35 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel;

import java.util.Locale;

public enum TargetType {

    REGRESSION, CLASSIFICATION;

    public static TargetType fromString(String name) {
        return valueOf(name.trim().toUpperCase(Locale.ROOT));
    }

    @Override
    public String toString() {
        return name().toLowerCase(Locale.ROOT);
    }
}
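A quick sketch of the round-trip behavior: fromString is case-insensitive (the input is trimmed and upper-cased with Locale.ROOT), while toString always yields the lowercase wire form. The wrapper class name is hypothetical.

    import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;

    public class TargetTypeSketch {
        public static void main(String[] args) {
            TargetType t = TargetType.fromString(" Classification ");  // trimmed, upper-cased with Locale.ROOT
            System.out.println(t == TargetType.CLASSIFICATION);       // true
            System.out.println(t);                                    // "classification" (lowercase wire form)
        }
    }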
@@ -0,0 +1,188 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

public class Ensemble implements TrainedModel {

    public static final String NAME = "ensemble";
    public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
    public static final ParseField TRAINED_MODELS = new ParseField("trained_models");
    public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
    public static final ParseField TARGET_TYPE = new ParseField("target_type");
    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");

    private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
        NAME,
        true,
        Ensemble.Builder::new);

    static {
        PARSER.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES);
        PARSER.declareNamedObjects(Ensemble.Builder::setTrainedModels,
            (p, c, n) -> p.namedObject(TrainedModel.class, n, null),
            (ensembleBuilder) -> { /* Noop does not matter client side */ },
            TRAINED_MODELS);
        PARSER.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
            (p, c, n) -> p.namedObject(OutputAggregator.class, n, null),
            (ensembleBuilder) -> { /* Noop does not matter client side */ },
            AGGREGATE_OUTPUT);
        PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
        PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
    }

    public static Ensemble fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null).build();
    }

    private final List<String> featureNames;
    private final List<TrainedModel> models;
    private final OutputAggregator outputAggregator;
    private final TargetType targetType;
    private final List<String> classificationLabels;

    Ensemble(List<String> featureNames,
             List<TrainedModel> models,
             @Nullable OutputAggregator outputAggregator,
             TargetType targetType,
             @Nullable List<String> classificationLabels) {
        this.featureNames = featureNames;
        this.models = models;
        this.outputAggregator = outputAggregator;
        this.targetType = targetType;
        this.classificationLabels = classificationLabels;
    }

    @Override
    public List<String> getFeatureNames() {
        return featureNames;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (featureNames != null) {
            builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
        }
        if (models != null) {
            NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models);
        }
        if (outputAggregator != null) {
            NamedXContentObjectHelper.writeNamedObjects(builder,
                params,
                false,
                AGGREGATE_OUTPUT.getPreferredName(),
                Collections.singletonList(outputAggregator));
        }
        if (targetType != null) {
            builder.field(TARGET_TYPE.getPreferredName(), targetType);
        }
        if (classificationLabels != null) {
            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Ensemble that = (Ensemble) o;
        return Objects.equals(featureNames, that.featureNames)
            && Objects.equals(models, that.models)
            && Objects.equals(targetType, that.targetType)
            && Objects.equals(classificationLabels, that.classificationLabels)
            && Objects.equals(outputAggregator, that.outputAggregator);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType);
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private List<String> featureNames;
        private List<TrainedModel> trainedModels;
        private OutputAggregator outputAggregator;
        private TargetType targetType;
        private List<String> classificationLabels;

        public Builder setFeatureNames(List<String> featureNames) {
            this.featureNames = featureNames;
            return this;
        }

        public Builder setTrainedModels(List<TrainedModel> trainedModels) {
            this.trainedModels = trainedModels;
            return this;
        }

        public Builder setOutputAggregator(OutputAggregator outputAggregator) {
            this.outputAggregator = outputAggregator;
            return this;
        }

        public Builder setTargetType(TargetType targetType) {
            this.targetType = targetType;
            return this;
        }

        public Builder setClassificationLabels(List<String> classificationLabels) {
            this.classificationLabels = classificationLabels;
            return this;
        }

        private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
            this.setOutputAggregator(outputAggregators.get(0));
        }

        private void setTargetType(String targetType) {
            this.targetType = TargetType.fromString(targetType);
        }

        public Ensemble build() {
            return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
        }
    }
}
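As a minimal usage sketch of the client-side builder: the client object is a plain XContent carrier and build() performs no validation, so the empty model list below is purely illustrative (a real ensemble would hold Tree members built elsewhere in this commit). The wrapper class name is hypothetical.

    import java.util.Arrays;
    import java.util.Collections;
    import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
    import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
    import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;

    public class EnsembleBuilderSketch {
        public static void main(String[] args) {
            Ensemble ensemble = Ensemble.builder()
                .setFeatureNames(Arrays.asList("f0", "f1"))
                .setTrainedModels(Collections.emptyList())          // illustrative only
                .setOutputAggregator(new WeightedMode(Arrays.asList(0.5, 0.5)))
                .setTargetType(TargetType.CLASSIFICATION)
                .setClassificationLabels(Arrays.asList("no", "yes"))
                .build();
            System.out.println(ensemble.getName()); // "ensemble"
        }
    }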
@@ -0,0 +1,28 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.client.ml.inference.NamedXContentObject;

public interface OutputAggregator extends NamedXContentObject {
    /**
     * @return The name of the output aggregator
     */
    String getName();
}
@@ -0,0 +1,84 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public class WeightedMode implements OutputAggregator {

    public static final String NAME = "weighted_mode";
    public static final ParseField WEIGHTS = new ParseField("weights");

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<WeightedMode, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
        a -> new WeightedMode((List<Double>)a[0]));

    static {
        PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
    }

    public static WeightedMode fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final List<Double> weights;

    public WeightedMode(List<Double> weights) {
        this.weights = weights;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (weights != null) {
            builder.field(WEIGHTS.getPreferredName(), weights);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        WeightedMode that = (WeightedMode) o;
        return Objects.equals(weights, that.weights);
    }

    @Override
    public int hashCode() {
        return Objects.hash(weights);
    }
}
@@ -0,0 +1,84 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public class WeightedSum implements OutputAggregator {

    public static final String NAME = "weighted_sum";
    public static final ParseField WEIGHTS = new ParseField("weights");

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<WeightedSum, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
        a -> new WeightedSum((List<Double>)a[0]));

    static {
        PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
    }

    public static WeightedSum fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final List<Double> weights;

    public WeightedSum(List<Double> weights) {
        this.weights = weights;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (weights != null) {
            builder.field(WEIGHTS.getPreferredName(), weights);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        WeightedSum that = (WeightedSum) o;
        return Objects.equals(weights, that.weights);
    }

    @Override
    public int hashCode() {
        return Objects.hash(weights);
    }
}
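Both aggregators are symmetric value objects. As the tests later in this commit exercise, a null weights list is accepted and is simply omitted from the XContent output. A small sketch (hypothetical wrapper class name):

    import java.util.Arrays;
    import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
    import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;

    public class AggregatorSketch {
        public static void main(String[] args) {
            WeightedSum sum = new WeightedSum(Arrays.asList(1.0, 2.0, 3.0)); // one weight per ensemble member
            WeightedMode mode = new WeightedMode(null);                      // weights are optional
            System.out.println(sum.getName() + " / " + mode.getName());      // weighted_sum / weighted_mode
        }
    }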
@@ -18,7 +18,9 @@
 */
package org.elasticsearch.client.ml.inference.trainedmodel.tree;

import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ObjectParser;

@@ -28,7 +30,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

@@ -39,12 +40,16 @@ public class Tree implements TrainedModel {

    public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
    public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
    public static final ParseField TARGET_TYPE = new ParseField("target_type");
    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");

    private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, true, Builder::new);

    static {
        PARSER.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES);
        PARSER.declareObjectArray(Builder::setNodes, (p, c) -> TreeNode.fromXContent(p), TREE_STRUCTURE);
        PARSER.declareString(Builder::setTargetType, TARGET_TYPE);
        PARSER.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
    }

    public static Tree fromXContent(XContentParser parser) {

@@ -53,10 +58,14 @@ public class Tree implements TrainedModel {

    private final List<String> featureNames;
    private final List<TreeNode> nodes;
    private final TargetType targetType;
    private final List<String> classificationLabels;

    Tree(List<String> featureNames, List<TreeNode> nodes) {
        this.featureNames = Collections.unmodifiableList(Objects.requireNonNull(featureNames));
        this.nodes = Collections.unmodifiableList(Objects.requireNonNull(nodes));
    Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
        this.featureNames = featureNames;
        this.nodes = nodes;
        this.targetType = targetType;
        this.classificationLabels = classificationLabels;
    }

    @Override

@@ -73,11 +82,30 @@ public class Tree implements TrainedModel {
        return nodes;
    }

    @Nullable
    public List<String> getClassificationLabels() {
        return classificationLabels;
    }

    public TargetType getTargetType() {
        return targetType;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
        builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
        if (featureNames != null) {
            builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
        }
        if (nodes != null) {
            builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
        }
        if (classificationLabels != null) {
            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
        }
        if (targetType != null) {
            builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
        }
        builder.endObject();
        return builder;
    }

@@ -93,12 +121,14 @@ public class Tree implements TrainedModel {
        if (o == null || getClass() != o.getClass()) return false;
        Tree that = (Tree) o;
        return Objects.equals(featureNames, that.featureNames)
            && Objects.equals(classificationLabels, that.classificationLabels)
            && Objects.equals(targetType, that.targetType)
            && Objects.equals(nodes, that.nodes);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureNames, nodes);
        return Objects.hash(featureNames, nodes, targetType, classificationLabels);
    }

    public static Builder builder() {

@@ -109,6 +139,8 @@ public class Tree implements TrainedModel {
        private List<String> featureNames;
        private ArrayList<TreeNode.Builder> nodes;
        private int numNodes;
        private TargetType targetType;
        private List<String> classificationLabels;

        public Builder() {
            nodes = new ArrayList<>();

@@ -137,6 +169,20 @@ public class Tree implements TrainedModel {
            return setNodes(Arrays.asList(nodes));
        }

        public Builder setTargetType(TargetType targetType) {
            this.targetType = targetType;
            return this;
        }

        public Builder setClassificationLabels(List<String> classificationLabels) {
            this.classificationLabels = classificationLabels;
            return this;
        }

        private void setTargetType(String targetType) {
            this.targetType = TargetType.fromString(targetType);
        }

        /**
         * Add a decision node. Space for the child nodes is allocated
         * @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index

@@ -185,7 +231,9 @@ public class Tree implements TrainedModel {

        public Tree build() {
            return new Tree(featureNames,
                nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()));
                nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()),
                targetType,
                classificationLabels);
        }
    }
@@ -65,6 +65,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;

@@ -681,7 +684,7 @@ public class RestHighLevelClientTests extends ESTestCase {

    public void testProvidedNamedXContents() {
        List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
        assertEquals(41, namedXContents.size());
        assertEquals(44, namedXContents.size());
        Map<Class<?>, Integer> categories = new HashMap<>();
        List<String> names = new ArrayList<>();
        for (NamedXContentRegistry.Entry namedXContent : namedXContents) {

@@ -691,7 +694,7 @@ public class RestHighLevelClientTests extends ESTestCase {
                categories.put(namedXContent.categoryClass, counter + 1);
            }
        }
        assertEquals("Had: " + categories, 11, categories.size());
        assertEquals("Had: " + categories, 12, categories.size());
        assertEquals(Integer.valueOf(3), categories.get(Aggregation.class));
        assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
        assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));

@@ -740,8 +743,11 @@ public class RestHighLevelClientTests extends ESTestCase {
            RSquaredMetric.NAME));
        assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
        assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
        assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
        assertThat(names, hasItems(Tree.NAME));
        assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
        assertThat(names, hasItems(Tree.NAME, Ensemble.NAME));
        assertEquals(Integer.valueOf(2),
            categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
        assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME));
    }

    public void testApiNamingConventions() throws Exception {
@@ -0,0 +1,97 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected Predicate<String> getRandomFieldsExcludeFilter() {
        return field -> !field.isEmpty();
    }

    @Override
    protected Ensemble doParseInstance(XContentParser parser) throws IOException {
        return Ensemble.fromXContent(parser);
    }

    public static Ensemble createRandom() {
        int numberOfFeatures = randomIntBetween(1, 10);
        List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10))
            .limit(numberOfFeatures)
            .collect(Collectors.toList());
        int numberOfModels = randomIntBetween(1, 10);
        List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
            .limit(numberOfModels)
            .collect(Collectors.toList());
        OutputAggregator outputAggregator = null;
        if (randomBoolean()) {
            List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
            outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
        }
        List<String> categoryLabels = null;
        if (randomBoolean()) {
            categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
        }
        return new Ensemble(featureNames,
            models,
            outputAggregator,
            randomFrom(TargetType.values()),
            categoryLabels);
    }

    @Override
    protected Ensemble createTestInstance() {
        return createRandom();
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
        return new NamedXContentRegistry(namedXContent);
    }
}
@@ -0,0 +1,51 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class WeightedModeTests extends AbstractXContentTestCase<WeightedMode> {

    WeightedMode createTestInstance(int numberOfWeights) {
        return new WeightedMode(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
    }

    @Override
    protected WeightedMode doParseInstance(XContentParser parser) throws IOException {
        return WeightedMode.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected WeightedMode createTestInstance() {
        return randomBoolean() ? new WeightedMode(null) : createTestInstance(randomIntBetween(1, 100));
    }

}
@@ -0,0 +1,51 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class WeightedSumTests extends AbstractXContentTestCase<WeightedSum> {

    WeightedSum createTestInstance(int numberOfWeights) {
        return new WeightedSum(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
    }

    @Override
    protected WeightedSum doParseInstance(XContentParser parser) throws IOException {
        return WeightedSum.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected WeightedSum createTestInstance() {
        return randomBoolean() ? new WeightedSum(null) : createTestInstance(randomIntBetween(1, 100));
    }

}
@@ -18,6 +18,7 @@
 */
package org.elasticsearch.client.ml.inference.trainedmodel.tree;

import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

@@ -51,16 +52,17 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
    }

    public static Tree createRandom() {
        return buildRandomTree(randomIntBetween(2, 15), 6);
    }

    public static Tree buildRandomTree(int numFeatures, int depth) {

        Tree.Builder builder = Tree.builder();
        List<String> featureNames = new ArrayList<>(numFeatures);
        for(int i = 0; i < numFeatures; i++) {
        int numberOfFeatures = randomIntBetween(1, 10);
        List<String> featureNames = new ArrayList<>();
        for (int i = 0; i < numberOfFeatures; i++) {
            featureNames.add(randomAlphaOfLength(10));
        }
        return buildRandomTree(featureNames, 6);
    }

    public static Tree buildRandomTree(List<String> featureNames, int depth) {
        int numFeatures = featureNames.size();
        Tree.Builder builder = Tree.builder();
        builder.setFeatureNames(featureNames);

        TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());

@@ -81,8 +83,13 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
            }
            childNodes = nextNodes;
        }

        return builder.build();
        List<String> categoryLabels = null;
        if (randomBoolean()) {
            categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
        }
        return builder.setClassificationLabels(categoryLabels)
            .setTargetType(randomFrom(TargetType.values()))
            .build();
    }

}
@@ -11,6 +11,12 @@ import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;

@@ -46,9 +52,27 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {

        // Model Lenient
        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Ensemble.NAME, Ensemble::fromXContentLenient));

        // Output Aggregator Lenient
        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
            WeightedMode.NAME,
            WeightedMode::fromXContentLenient));
        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
            WeightedSum.NAME,
            WeightedSum::fromXContentLenient));

        // Model Strict
        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict));
        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Ensemble.NAME, Ensemble::fromXContentStrict));

        // Output Aggregator Strict
        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
            WeightedMode.NAME,
            WeightedMode::fromXContentStrict));
        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
            WeightedSum.NAME,
            WeightedSum::fromXContentStrict));

        return namedXContent;
    }

@@ -66,6 +90,15 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {

        // Model
        namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new));
        namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new));

        // Output Aggregator
        namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
            WeightedSum.NAME.getPreferredName(),
            WeightedSum::new));
        namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
            WeightedMode.NAME.getPreferredName(),
            WeightedMode::new));

        return namedWriteables;
    }
@@ -0,0 +1,36 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.Locale;

public enum TargetType implements Writeable {

    REGRESSION, CLASSIFICATION;

    public static TargetType fromString(String name) {
        return valueOf(name.trim().toUpperCase(Locale.ROOT));
    }

    public static TargetType fromStream(StreamInput in) throws IOException {
        return in.readEnum(TargetType.class);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeEnum(this);
    }

    @Override
    public String toString() {
        return name().toLowerCase(Locale.ROOT);
    }
}
@@ -5,6 +5,7 @@
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

@@ -28,17 +29,47 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable {
    double infer(Map<String, Object> fields);

    /**
     * @return {@code true} if the model is classification, {@code false} otherwise.
     * @param fields similar to {@link TrainedModel#infer(Map)}, but fields are already in order and doubles
     * @return The predicted value.
     */
    boolean isClassification();
    double infer(List<Double> fields);

    /**
     * @return {@link TargetType} for the model.
     */
    TargetType targetType();

    /**
     * This gathers the probabilities for each potential classification value.
     *
     * The probabilities are indexed by classification ordinal label encoding.
     * The length of this list is equal to the number of classification labels.
     *
     * This should only return a value if the implementing model is inferring classification values and not regression.
     * @param fields The fields and their values to infer against
     * @return The probabilities of each classification value
     */
    List<Double> inferProbabilities(Map<String, Object> fields);
    List<Double> classificationProbability(Map<String, Object> fields);

    /**
     * @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles
     * @return The probabilities of each classification value
     */
    List<Double> classificationProbability(List<Double> fields);

    /**
     * The ordinal encoded list of the classification labels.
     * @return Ordinal encoded list of classification labels.
     */
    @Nullable
    List<String> classificationLabels();

    /**
     * Runs validations against the model.
     *
     * Example: {@link org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree} should check if there are any loops
     *
     * @throws org.elasticsearch.ElasticsearchException if validations fail
     */
    void validate();
}
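As a hedged sketch of how a caller would drive this extended contract (the helper method and wrapper class are hypothetical; model can be any implementation, such as the Tree or Ensemble in this commit):

    import java.util.List;
    import java.util.Map;
    import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
    import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;

    public class TrainedModelUsageSketch {
        // Hypothetical helper: run inference on a document and, for classification
        // models, pair each ordinal probability with its label.
        static void predict(TrainedModel model, Map<String, Object> doc) {
            if (model.targetType() == TargetType.CLASSIFICATION) {
                List<Double> probs = model.classificationProbability(doc);
                List<String> labels = model.classificationLabels(); // may be null
                for (int i = 0; labels != null && i < labels.size(); i++) {
                    System.out.println(labels.get(i) + " -> " + probs.get(i));
                }
            } else {
                System.out.println("regression prediction: " + model.infer(doc));
            }
        }
    }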
@ -0,0 +1,311 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
|
||||
|
||||
// TODO should we have regression/classification sub-classes that accept the builder?
|
||||
public static final ParseField NAME = new ParseField("ensemble");
|
||||
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
|
||||
public static final ParseField TRAINED_MODELS = new ParseField("trained_models");
|
||||
public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
|
||||
public static final ParseField TARGET_TYPE = new ParseField("target_type");
|
||||
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
|
||||
|
||||
private static final ObjectParser<Ensemble.Builder, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ObjectParser<Ensemble.Builder, Void> STRICT_PARSER = createParser(false);
|
||||
|
||||
private static ObjectParser<Ensemble.Builder, Void> createParser(boolean lenient) {
|
||||
ObjectParser<Ensemble.Builder, Void> parser = new ObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
Ensemble.Builder::builderForParser);
|
||||
parser.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES);
|
||||
parser.declareNamedObjects(Ensemble.Builder::setTrainedModels,
|
||||
(p, c, n) ->
|
||||
lenient ? p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
|
||||
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
|
||||
(ensembleBuilder) -> ensembleBuilder.setModelsAreOrdered(true),
|
||||
TRAINED_MODELS);
|
||||
parser.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
|
||||
(p, c, n) ->
|
||||
lenient ? p.namedObject(LenientlyParsedOutputAggregator.class, n, null) :
|
||||
p.namedObject(StrictlyParsedOutputAggregator.class, n, null),
|
||||
(ensembleBuilder) -> {/*Noop as it could be an array or object, it just has to be a one*/},
|
||||
AGGREGATE_OUTPUT);
|
||||
parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
|
||||
parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
|
||||
return parser;
|
||||
}
|
||||
|
||||
public static Ensemble fromXContentStrict(XContentParser parser) {
|
||||
return STRICT_PARSER.apply(parser, null).build();
|
||||
}
|
||||
|
||||
public static Ensemble fromXContentLenient(XContentParser parser) {
|
||||
return LENIENT_PARSER.apply(parser, null).build();
|
||||
}
|
||||
|
||||
private final List<String> featureNames;
|
||||
private final List<TrainedModel> models;
|
||||
private final OutputAggregator outputAggregator;
|
||||
private final TargetType targetType;
|
||||
private final List<String> classificationLabels;
|
||||
|
||||
Ensemble(List<String> featureNames,
|
||||
List<TrainedModel> models,
|
||||
OutputAggregator outputAggregator,
|
||||
TargetType targetType,
|
||||
@Nullable List<String> classificationLabels) {
|
||||
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
|
||||
this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS));
|
||||
this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
|
||||
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
|
||||
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
|
||||
}
|
||||
|
||||
public Ensemble(StreamInput in) throws IOException {
|
||||
this.featureNames = Collections.unmodifiableList(in.readStringList());
|
||||
this.models = Collections.unmodifiableList(in.readNamedWriteableList(TrainedModel.class));
|
||||
this.outputAggregator = in.readNamedWriteable(OutputAggregator.class);
|
||||
this.targetType = TargetType.fromStream(in);
|
||||
if (in.readBoolean()) {
|
||||
this.classificationLabels = in.readStringList();
|
||||
} else {
|
||||
this.classificationLabels = null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getFeatureNames() {
|
||||
return featureNames;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double infer(Map<String, Object> fields) {
|
||||
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
|
||||
return infer(features);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double infer(List<Double> fields) {
|
||||
List<Double> processedInferences = inferAndProcess(fields);
|
||||
return outputAggregator.aggregate(processedInferences);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TargetType targetType() {
|
||||
return targetType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> classificationProbability(Map<String, Object> fields) {
|
||||
if ((targetType == TargetType.CLASSIFICATION) == false) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
|
||||
}
|
||||
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
|
||||
return classificationProbability(features);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> classificationProbability(List<Double> fields) {
|
||||
if ((targetType == TargetType.CLASSIFICATION) == false) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
|
||||
}
|
||||
return inferAndProcess(fields);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> classificationLabels() {
|
||||
return classificationLabels;
|
||||
}
|
||||
|
||||
private List<Double> inferAndProcess(List<Double> fields) {
|
||||
List<Double> modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList());
|
||||
return outputAggregator.processValues(modelInferences);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeStringCollection(featureNames);
|
||||
out.writeNamedWriteableList(models);
|
||||
out.writeNamedWriteable(outputAggregator);
|
||||
targetType.writeTo(out);
|
||||
out.writeBoolean(classificationLabels != null);
|
||||
if (classificationLabels != null) {
|
||||
out.writeStringCollection(classificationLabels);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models);
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||
params,
|
||||
false,
|
||||
AGGREGATE_OUTPUT.getPreferredName(),
|
||||
Collections.singletonList(outputAggregator));
|
||||
builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
|
||||
if (classificationLabels != null) {
|
||||
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Ensemble that = (Ensemble) o;
        return Objects.equals(featureNames, that.featureNames)
            && Objects.equals(models, that.models)
            && Objects.equals(targetType, that.targetType)
            && Objects.equals(classificationLabels, that.classificationLabels)
            && Objects.equals(outputAggregator, that.outputAggregator);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureNames, models, outputAggregator, targetType, classificationLabels);
    }

    @Override
    public void validate() {
        if (this.featureNames != null) {
            if (this.models.stream()
                .anyMatch(trainedModel -> trainedModel.getFeatureNames().equals(this.featureNames) == false)) {
                throw ExceptionsHelper.badRequestException(
                    "[{}] must be the same and in the same order for each of the {}",
                    FEATURE_NAMES.getPreferredName(),
                    TRAINED_MODELS.getPreferredName());
            }
        }
        if (outputAggregator.expectedValueSize() != null &&
            outputAggregator.expectedValueSize() != models.size()) {
            throw ExceptionsHelper.badRequestException(
                "[{}] expects value array of size [{}] but number of models is [{}]",
                AGGREGATE_OUTPUT.getPreferredName(),
                outputAggregator.expectedValueSize(),
                models.size());
        }
        if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
            throw ExceptionsHelper.badRequestException(
                "[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
        }
        this.models.forEach(TrainedModel::validate);
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private List<String> featureNames;
        private List<TrainedModel> trainedModels;
        private OutputAggregator outputAggregator = new WeightedSum();
        private TargetType targetType = TargetType.REGRESSION;
        private List<String> classificationLabels;
        private boolean modelsAreOrdered;

        private Builder(boolean modelsAreOrdered) {
            this.modelsAreOrdered = modelsAreOrdered;
        }

        private static Builder builderForParser() {
            return new Builder(false);
        }

        public Builder() {
            this(true);
        }

        public Builder setFeatureNames(List<String> featureNames) {
            this.featureNames = featureNames;
            return this;
        }

        public Builder setTrainedModels(List<TrainedModel> trainedModels) {
            this.trainedModels = trainedModels;
            return this;
        }

        public Builder setOutputAggregator(OutputAggregator outputAggregator) {
            this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
            return this;
        }

        public Builder setTargetType(TargetType targetType) {
            this.targetType = targetType;
            return this;
        }

        public Builder setClassificationLabels(List<String> classificationLabels) {
            this.classificationLabels = classificationLabels;
            return this;
        }

        private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
            if (outputAggregators.size() != 1) {
                throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.",
                    AGGREGATE_OUTPUT.getPreferredName());
            }
            this.setOutputAggregator(outputAggregators.get(0));
        }

        private void setTargetType(String targetType) {
            this.targetType = TargetType.fromString(targetType);
        }

        private void setModelsAreOrdered(boolean value) {
            this.modelsAreOrdered = value;
        }

        public Ensemble build() {
            // This is essentially a serialization error, but the underlying xcontent parsing does not allow us to
            // inject this requirement. So, we verify that the models were parsed in an ordered fashion here instead.
            if (modelsAreOrdered == false && trainedModels != null && trainedModels.size() > 1) {
                throw ExceptionsHelper.badRequestException("[trained_models] needs to be an array of objects");
            }
            return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
        }
    }
}
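A minimal usage sketch of the builder above (editorial, not part of this commit; `tree` is a placeholder for any TrainedModel). It shows the build-then-validate flow that the validation tests later in this commit exercise:

    Ensemble ensemble = Ensemble.builder()
        .setFeatureNames(Arrays.asList("foo", "bar"))
        .setTrainedModels(Arrays.asList(tree))                    // `tree` is hypothetical; models must share feature_names
        .setOutputAggregator(new WeightedSum(Arrays.asList(1.0))) // one weight per model
        .setTargetType(TargetType.REGRESSION)
        .build();
    ensemble.validate(); // throws a bad-request exception on any of the checks above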
@@ -0,0 +1,10 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;


public interface LenientlyParsedOutputAggregator extends OutputAggregator {
}
@@ -0,0 +1,47 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

import java.util.List;

public interface OutputAggregator extends NamedXContentObject, NamedWriteable {

    /**
     * @return The expected size of the values array when aggregating. `null` implies there is no expected size.
     */
    Integer expectedValueSize();

    /**
     * This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(List)} method.
     *
     * Two major types of pre-processed values could be returned:
     * - The confidence/probability scaled values given the input values (see {@link WeightedMode#processValues(List)})
     * - A simple transformation of the passed values in preparation for aggregation (see {@link WeightedSum#processValues(List)})
     * @param values the values to process
     * @return A new list containing the processed values, or the same list if no processing is required
     */
    List<Double> processValues(List<Double> values);

    /**
     * Function to aggregate the processed values into a single double
     *
     * This may be as simple as returning the index of the maximum value.
     *
     * Or as complex as a mathematical reduction of all the passed values (e.g. summation, averaging, etc.).
     *
     * @param processedValues The values to aggregate
     * @return the aggregated value.
     */
    double aggregate(List<Double> processedValues);

    /**
     * @return The name of the output aggregator
     */
    String getName();
}
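The interface splits aggregation into two phases: transform the raw model outputs, then reduce them to one number. A hypothetical sketch (editorial, not part of this commit, omitting the NamedWriteable/NamedXContentObject plumbing) of an unweighted mean shows the contract:

    // Hypothetical example only: an unweighted mean over model outputs.
    public class MeanAggregator {
        public Integer expectedValueSize() {
            return null; // accepts any number of model outputs
        }
        public List<Double> processValues(List<Double> values) {
            return values; // nothing to transform before aggregation
        }
        public double aggregate(List<Double> processedValues) {
            return processedValues.stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
        }
    }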
@@ -0,0 +1,10 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;


public interface StrictlyParsedOutputAggregator extends OutputAggregator {
}
@@ -0,0 +1,161 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {

    public static final ParseField NAME = new ParseField("weighted_mode");
    public static final ParseField WEIGHTS = new ParseField("weights");

    private static final ConstructingObjectParser<WeightedMode, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<WeightedMode, Void> STRICT_PARSER = createParser(false);

    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<WeightedMode, Void> createParser(boolean lenient) {
        ConstructingObjectParser<WeightedMode, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new WeightedMode((List<Double>)a[0]));
        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
        return parser;
    }

    public static WeightedMode fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    public static WeightedMode fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    private final List<Double> weights;

    WeightedMode() {
        this.weights = null;
    }

    public WeightedMode(List<Double> weights) {
        this.weights = weights == null ? null : Collections.unmodifiableList(weights);
    }

    public WeightedMode(StreamInput in) throws IOException {
        if (in.readBoolean()) {
            this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
        } else {
            this.weights = null;
        }
    }

    @Override
    public Integer expectedValueSize() {
        return this.weights == null ? null : this.weights.size();
    }

    @Override
    public List<Double> processValues(List<Double> values) {
        Objects.requireNonNull(values, "values must not be null");
        if (weights != null && values.size() != weights.size()) {
            throw new IllegalArgumentException("values must be the same length as weights.");
        }
        List<Integer> freqArray = new ArrayList<>();
        Integer maxVal = 0;
        for (Double value : values) {
            if (value == null) {
                throw new IllegalArgumentException("values must not contain null values");
            }
            if (Double.isNaN(value) || Double.isInfinite(value) || value < 0.0 || value != Math.rint(value)) {
                throw new IllegalArgumentException("values must be whole, finite, and non-negative");
            }
            Integer integerValue = value.intValue();
            freqArray.add(integerValue);
            if (integerValue > maxVal) {
                maxVal = integerValue;
            }
        }
        List<Double> frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY));
        for (int i = 0; i < freqArray.size(); i++) {
            Double weight = weights == null ? 1.0 : weights.get(i);
            Integer value = freqArray.get(i);
            Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
            frequencies.set(value, frequency);
        }
        return softMax(frequencies);
    }

    @Override
    public double aggregate(List<Double> values) {
        Objects.requireNonNull(values, "values must not be null");
        int bestValue = 0;
        double bestFreq = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.size(); i++) {
            if (values.get(i) == null) {
                throw new IllegalArgumentException("values must not contain null values");
            }
            if (values.get(i) > bestFreq) {
                bestFreq = values.get(i);
                bestValue = i;
            }
        }
        return bestValue;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(weights != null);
        if (weights != null) {
            out.writeCollection(weights, StreamOutput::writeDouble);
        }
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (weights != null) {
            builder.field(WEIGHTS.getPreferredName(), weights);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        WeightedMode that = (WeightedMode) o;
        return Objects.equals(weights, that.weights);
    }

    @Override
    public int hashCode() {
        return Objects.hash(weights);
    }
}
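A worked example (editorial, not part of this commit) of the weighted_mode path above: each value is a class vote, processValues builds weighted vote totals and runs them through softMax, and aggregate returns the winning class index:

    // Votes [0.0, 1.0, 1.0] with weights [0.5, 0.3, 0.3]:
    // class 0 accumulates 0.5 and class 1 accumulates 0.6.
    WeightedMode mode = new WeightedMode(Arrays.asList(0.5, 0.3, 0.3));
    List<Double> probabilities = mode.processValues(Arrays.asList(0.0, 1.0, 1.0)); // softMax([0.5, 0.6])
    double winner = mode.aggregate(probabilities); // 1.0, the index with the highest probability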
@@ -0,0 +1,138 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {

    public static final ParseField NAME = new ParseField("weighted_sum");
    public static final ParseField WEIGHTS = new ParseField("weights");

    private static final ConstructingObjectParser<WeightedSum, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<WeightedSum, Void> STRICT_PARSER = createParser(false);

    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<WeightedSum, Void> createParser(boolean lenient) {
        ConstructingObjectParser<WeightedSum, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new WeightedSum((List<Double>)a[0]));
        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
        return parser;
    }

    public static WeightedSum fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    public static WeightedSum fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    private final List<Double> weights;

    WeightedSum() {
        this.weights = null;
    }

    public WeightedSum(List<Double> weights) {
        this.weights = weights == null ? null : Collections.unmodifiableList(weights);
    }

    public WeightedSum(StreamInput in) throws IOException {
        if (in.readBoolean()) {
            this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
        } else {
            this.weights = null;
        }
    }

    @Override
    public List<Double> processValues(List<Double> values) {
        Objects.requireNonNull(values, "values must not be null");
        if (weights == null) {
            return values;
        }
        if (values.size() != weights.size()) {
            throw new IllegalArgumentException("values must be the same length as weights.");
        }
        return IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).boxed().collect(Collectors.toList());
    }

    @Override
    public double aggregate(List<Double> values) {
        Objects.requireNonNull(values, "values must not be null");
        if (values.isEmpty()) {
            throw new IllegalArgumentException("values must not be empty");
        }
        // A null element would make reduce(Double::sum) throw a NullPointerException while unboxing,
        // so reject nulls explicitly before summing.
        if (values.stream().anyMatch(Objects::isNull)) {
            throw new IllegalArgumentException("values must not contain null values");
        }
        return values.stream().mapToDouble(Double::doubleValue).sum();
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(weights != null);
        if (weights != null) {
            out.writeCollection(weights, StreamOutput::writeDouble);
        }
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (weights != null) {
            builder.field(WEIGHTS.getPreferredName(), weights);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        WeightedSum that = (WeightedSum) o;
        return Objects.equals(weights, that.weights);
    }

    @Override
    public int hashCode() {
        return Objects.hash(weights);
    }

    @Override
    public Integer expectedValueSize() {
        return weights == null ? null : this.weights.size();
    }
}
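A worked example (editorial, not part of this commit): weighted_sum scales each model's output by its weight in processValues, then aggregate sums the results:

    WeightedSum sum = new WeightedSum(Arrays.asList(0.5, 0.25));
    double result = sum.aggregate(sum.processValues(Arrays.asList(1.0, 2.0))); // 0.5 * 1.0 + 0.25 * 2.0 = 1.0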
@@ -9,11 +9,13 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;

@@ -31,10 +33,13 @@ import java.util.stream.Collectors;

public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {

    // TODO should we have regression/classification sub-classes that accept the builder?
    public static final ParseField NAME = new ParseField("tree");

    public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
    public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
    public static final ParseField TARGET_TYPE = new ParseField("target_type");
    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");

    private static final ObjectParser<Tree.Builder, Void> LENIENT_PARSER = createParser(true);
    private static final ObjectParser<Tree.Builder, Void> STRICT_PARSER = createParser(false);

@@ -46,6 +51,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
            Tree.Builder::new);
        parser.declareStringArray(Tree.Builder::setFeatureNames, FEATURE_NAMES);
        parser.declareObjectArray(Tree.Builder::setNodes, (p, c) -> TreeNode.fromXContent(p, lenient), TREE_STRUCTURE);
        parser.declareString(Tree.Builder::setTargetType, TARGET_TYPE);
        parser.declareStringArray(Tree.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
        return parser;
    }

@@ -59,15 +66,28 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {

    private final List<String> featureNames;
    private final List<TreeNode> nodes;
    private final TargetType targetType;
    private final List<String> classificationLabels;
    private final CachedSupplier<Double> highestOrderCategory;

    Tree(List<String> featureNames, List<TreeNode> nodes) {
    Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
        this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
        this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE));
        this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
        this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
        this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
    }

    public Tree(StreamInput in) throws IOException {
        this.featureNames = Collections.unmodifiableList(in.readStringList());
        this.nodes = Collections.unmodifiableList(in.readList(TreeNode::new));
        this.targetType = TargetType.fromStream(in);
        if (in.readBoolean()) {
            this.classificationLabels = Collections.unmodifiableList(in.readStringList());
        } else {
            this.classificationLabels = null;
        }
        this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
    }

    @Override

@@ -90,7 +110,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
        return infer(features);
    }

    private double infer(List<Double> features) {
    @Override
    public double infer(List<Double> features) {
        TreeNode node = nodes.get(0);
        while (node.isLeaf() == false) {
            node = nodes.get(node.compare(features));

@@ -115,13 +136,40 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
    }

    @Override
    public boolean isClassification() {
        return false;
    public TargetType targetType() {
        return targetType;
    }

    @Override
    public List<Double> inferProbabilities(Map<String, Object> fields) {
        throw new UnsupportedOperationException("Cannot infer probabilities against a regression model.");
    public List<Double> classificationProbability(Map<String, Object> fields) {
        if ((targetType == TargetType.CLASSIFICATION) == false) {
            throw new UnsupportedOperationException(
                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
        }
        return classificationProbability(featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()));
    }

    @Override
    public List<Double> classificationProbability(List<Double> fields) {
        if ((targetType == TargetType.CLASSIFICATION) == false) {
            throw new UnsupportedOperationException(
                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
        }
        double label = infer(fields);
        // If we are classification, we should assume that the inference return value is whole.
        assert label == Math.rint(label);
        double maxCategory = this.highestOrderCategory.get();
        // If we are classification, we should assume that the largest leaf value is whole.
        assert maxCategory == Math.rint(maxCategory);
        List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
        // TODO, eventually have TreeNodes contain confidence levels
        list.set(Double.valueOf(label).intValue(), 1.0);
        return list;
    }

    @Override
    public List<String> classificationLabels() {
        return classificationLabels;
    }

    @Override

@@ -133,6 +181,11 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
    public void writeTo(StreamOutput out) throws IOException {
        out.writeStringCollection(featureNames);
        out.writeCollection(nodes);
        targetType.writeTo(out);
        out.writeBoolean(classificationLabels != null);
        if (classificationLabels != null) {
            out.writeStringCollection(classificationLabels);
        }
    }

    @Override

@@ -140,6 +193,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
        builder.startObject();
        builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
        builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
        builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
        if (classificationLabels != null) {
            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
        }
        builder.endObject();
        return builder;
    }

@@ -155,22 +212,96 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
        if (o == null || getClass() != o.getClass()) return false;
        Tree that = (Tree) o;
        return Objects.equals(featureNames, that.featureNames)
            && Objects.equals(nodes, that.nodes);
            && Objects.equals(nodes, that.nodes)
            && Objects.equals(targetType, that.targetType)
            && Objects.equals(classificationLabels, that.classificationLabels);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureNames, nodes);
        return Objects.hash(featureNames, nodes, targetType, classificationLabels);
    }

    public static Builder builder() {
        return new Builder();
    }

    @Override
    public void validate() {
        checkTargetType();
        detectMissingNodes();
        detectCycle();
    }

    private void checkTargetType() {
        if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
            throw ExceptionsHelper.badRequestException(
                "[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
        }
    }

    private void detectCycle() {
        if (nodes.isEmpty()) {
            return;
        }
        Set<Integer> visited = new HashSet<>(nodes.size());
        Queue<Integer> toVisit = new ArrayDeque<>(nodes.size());
        toVisit.add(0);
        while (toVisit.isEmpty() == false) {
            Integer nodeIdx = toVisit.remove();
            if (visited.contains(nodeIdx)) {
                throw ExceptionsHelper.badRequestException("[tree] contains cycle at node {}", nodeIdx);
            }
            visited.add(nodeIdx);
            TreeNode treeNode = nodes.get(nodeIdx);
            if (treeNode.getLeftChild() >= 0) {
                toVisit.add(treeNode.getLeftChild());
            }
            if (treeNode.getRightChild() >= 0) {
                toVisit.add(treeNode.getRightChild());
            }
        }
    }

    private void detectMissingNodes() {
        if (nodes.isEmpty()) {
            return;
        }

        List<Integer> missingNodes = new ArrayList<>();
        for (int i = 0; i < nodes.size(); i++) {
            TreeNode currentNode = nodes.get(i);
            if (currentNode == null) {
                continue;
            }
            if (nodeMissing(currentNode.getLeftChild(), nodes)) {
                missingNodes.add(currentNode.getLeftChild());
            }
            if (nodeMissing(currentNode.getRightChild(), nodes)) {
                missingNodes.add(currentNode.getRightChild());
            }
        }
        if (missingNodes.isEmpty() == false) {
            throw ExceptionsHelper.badRequestException("[tree] contains missing nodes {}", missingNodes);
        }
    }

    private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
        return nodeIdx >= nodes.size();
    }

    private Double maxLeafValue() {
        return targetType == TargetType.CLASSIFICATION ?
            this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() :
            null;
    }

    public static class Builder {
        private List<String> featureNames;
        private ArrayList<TreeNode.Builder> nodes;
        private int numNodes;
        private TargetType targetType = TargetType.REGRESSION;
        private List<String> classificationLabels;

        public Builder() {
            nodes = new ArrayList<>();

@@ -185,13 +316,18 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
            return this;
        }

        public Builder setRoot(TreeNode.Builder root) {
            nodes.set(0, root);
            return this;
        }

        public Builder addNode(TreeNode.Builder node) {
            nodes.add(node);
            return this;
        }

        public Builder setNodes(List<TreeNode.Builder> nodes) {
            this.nodes = new ArrayList<>(nodes);
            this.nodes = new ArrayList<>(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE.getPreferredName()));
            return this;
        }

@@ -199,6 +335,21 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
            return setNodes(Arrays.asList(nodes));
        }

        public Builder setTargetType(TargetType targetType) {
            this.targetType = targetType;
            return this;
        }

        public Builder setClassificationLabels(List<String> classificationLabels) {
            this.classificationLabels = classificationLabels;
            return this;
        }

        private void setTargetType(String targetType) {
            this.targetType = TargetType.fromString(targetType);
        }

        /**
         * Add a decision node. Space for the child nodes is allocated
         * @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index

@@ -231,61 +382,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
            return node;
        }

        void detectCycle(List<TreeNode.Builder> nodes) {
            if (nodes.isEmpty()) {
                return;
            }
            Set<Integer> visited = new HashSet<>();
            Queue<Integer> toVisit = new ArrayDeque<>(nodes.size());
            toVisit.add(0);
            while (toVisit.isEmpty() == false) {
                Integer nodeIdx = toVisit.remove();
                if (visited.contains(nodeIdx)) {
                    throw new IllegalArgumentException("[tree] contains cycle at node " + nodeIdx);
                }
                visited.add(nodeIdx);
                TreeNode.Builder treeNode = nodes.get(nodeIdx);
                if (treeNode.getLeftChild() != null) {
                    toVisit.add(treeNode.getLeftChild());
                }
                if (treeNode.getRightChild() != null) {
                    toVisit.add(treeNode.getRightChild());
                }
            }
        }

        void detectNullOrMissingNode(List<TreeNode.Builder> nodes) {
            if (nodes.isEmpty()) {
                return;
            }
            if (nodes.get(0) == null) {
                throw new IllegalArgumentException("[tree] must have non-null root node.");
            }
            List<Integer> nullOrMissingNodes = new ArrayList<>();
            for (int i = 0; i < nodes.size(); i++) {
                TreeNode.Builder currentNode = nodes.get(i);
                if (currentNode == null) {
                    continue;
                }
                if (nodeNullOrMissing(currentNode.getLeftChild())) {
                    nullOrMissingNodes.add(currentNode.getLeftChild());
                }
                if (nodeNullOrMissing(currentNode.getRightChild())) {
                    nullOrMissingNodes.add(currentNode.getRightChild());
                }
            }
            if (nullOrMissingNodes.isEmpty() == false) {
                throw new IllegalArgumentException("[tree] contains null or missing nodes " + nullOrMissingNodes);
            }
        }

        private boolean nodeNullOrMissing(Integer nodeIdx) {
            if (nodeIdx == null) {
                return false;
            }
            return nodeIdx >= nodes.size() || nodes.get(nodeIdx) == null;
        }

        /**
         * Sets the node at {@code nodeIndex} to a leaf node.
         * @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)}

@@ -301,10 +397,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
        }

        public Tree build() {
            detectNullOrMissingNode(nodes);
            detectCycle(nodes);
            if (nodes.stream().anyMatch(Objects::isNull)) {
                throw ExceptionsHelper.badRequestException("[tree] cannot contain null nodes");
            }
            return new Tree(featureNames,
                nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()));
                nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()),
                targetType,
                classificationLabels);
        }
    }
}
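A worked sketch (editorial, not part of this commit) of the new classification path through Tree, mirroring the builder usage in the tests later in this commit: infer walks compare() from the root to a leaf, and classificationProbability one-hot encodes the leaf value over categories 0 through the largest leaf value.

    Tree tree = Tree.builder()
        .setTargetType(TargetType.CLASSIFICATION)
        .setClassificationLabels(Arrays.asList("no", "yes"))
        .setFeatureNames(Arrays.asList("foo"))
        .setRoot(TreeNode.builder(0).setLeftChild(1).setRightChild(2).setSplitFeature(0).setThreshold(0.5))
        .addNode(TreeNode.builder(1).setLeafValue(0.0))
        .addNode(TreeNode.builder(2).setLeafValue(1.0))
        .build();
    // Assuming the usual convention that values below the threshold follow the left branch,
    // foo = 0.7 lands on the leaf holding 1.0, and the probability list is one-hot: [0.0, 1.0]
    List<Double> probabilities = tree.classificationProbability(Collections.singletonMap("foo", 0.7));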
@@ -143,7 +143,7 @@ public class TreeNode implements ToXContentObject, Writeable {
    }

    public boolean isLeaf() {
        return leftChild < 1;
        return leftChild < 0;
    }

    public int compare(List<Double> features) {
@@ -0,0 +1,52 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.utils;

import java.util.List;
import java.util.stream.Collectors;

public final class Statistics {

    private Statistics() {}

    /**
     * Calculates the softMax of the passed values.
     *
     * Any {@link Double#isInfinite()}, {@link Double#isNaN()}, or `null` values are ignored in the calculation and
     * are returned as 0.0 in the softMax result.
     * @param values Values on which to run SoftMax.
     * @return A new list containing the softMax of the passed values
     */
    public static List<Double> softMax(List<Double> values) {
        Double expSum = 0.0;
        Double max = values.stream().filter(v -> isInvalid(v) == false).max(Double::compareTo).orElse(null);
        if (max == null) {
            throw new IllegalArgumentException("no valid values present");
        }
        List<Double> exps = values.stream().map(v -> isInvalid(v) ? Double.NEGATIVE_INFINITY : v - max)
            .collect(Collectors.toList());
        for (int i = 0; i < exps.size(); i++) {
            if (isInvalid(exps.get(i)) == false) {
                Double exp = Math.exp(exps.get(i));
                expSum += exp;
                exps.set(i, exp);
            }
        }
        for (int i = 0; i < exps.size(); i++) {
            if (isInvalid(exps.get(i))) {
                exps.set(i, 0.0);
            } else {
                exps.set(i, exps.get(i) / expSum);
            }
        }
        return exps;
    }

    public static boolean isInvalid(Double v) {
        return v == null || Double.isInfinite(v) || Double.isNaN(v);
    }

}
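A numeric illustration (editorial, not part of this commit) of why softMax subtracts the max before exponentiating: the shift cancels in the ratio but keeps Math.exp from overflowing.

    // softMax([1000.0, 1001.0]) is computed as softMax([-1.0, 0.0]):
    // [e^-1, e^0] / (e^-1 + e^0) ~= [0.269, 0.731]; a naive e^1000 would overflow to Infinity.
    List<Double> probabilities = Statistics.softMax(Arrays.asList(1000.0, 1001.0));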
@@ -17,6 +17,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;

@@ -157,7 +158,7 @@ public class NamedXContentObjectsTests extends AbstractXContentTestCase<NamedXCo
        NamedObjectContainer container = new NamedObjectContainer();
        container.setPreProcessors(preProcessors);
        container.setUseExplicitPreprocessorOrder(true);
        container.setModel(TreeTests.buildRandomTree(5, 4));
        container.setModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom()));
        return container;
    }
@@ -0,0 +1,402 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;

public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {

    private boolean lenient;

    @Before
    public void chooseStrictOrLenient() {
        lenient = randomBoolean();
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected Predicate<String> getRandomFieldsExcludeFilter() {
        return field -> !field.isEmpty();
    }

    @Override
    protected Ensemble doParseInstance(XContentParser parser) throws IOException {
        return lenient ? Ensemble.fromXContentLenient(parser) : Ensemble.fromXContentStrict(parser);
    }

    public static Ensemble createRandom() {
        int numberOfFeatures = randomIntBetween(1, 10);
        List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList());
        int numberOfModels = randomIntBetween(1, 10);
        List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
            .limit(numberOfModels)
            .collect(Collectors.toList());
        List<Double> weights = randomBoolean() ?
            null :
            Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
        OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
        List<String> categoryLabels = null;
        if (randomBoolean()) {
            categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
        }

        return new Ensemble(featureNames,
            models,
            outputAggregator,
            randomFrom(TargetType.values()),
            categoryLabels);
    }

    @Override
    protected Ensemble createTestInstance() {
        return createRandom();
    }

    @Override
    protected Writeable.Reader<Ensemble> instanceReader() {
        return Ensemble::new;
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
        return new NamedXContentRegistry(namedXContent);
    }

    @Override
    protected NamedWriteableRegistry getNamedWriteableRegistry() {
        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
        return new NamedWriteableRegistry(entries);
    }

    public void testEnsembleWithModelsThatHaveDifferentFeatureNames() {
        List<String> featureNames = Arrays.asList("foo", "bar", "baz", "farequote");
        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder().setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("bar", "foo", "baz", "farequote"), 6)))
                .build()
                .validate();
        });
        assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));

        ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder().setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("completely_different"), 6)))
                .build()
                .validate();
        });
        assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));
    }

    public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        int numberOfModels = 5;
        List<Double> weights = new ArrayList<>(numberOfModels + 2);
        for (int i = 0; i < numberOfModels + 2; i++) {
            weights.add(randomDouble());
        }
        OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));

        List<TrainedModel> models = new ArrayList<>(numberOfModels);
        for (int i = 0; i < numberOfModels; i++) {
            models.add(TreeTests.buildRandomTree(featureNames, 6));
        }
        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder()
                .setTrainedModels(models)
                .setOutputAggregator(outputAggregator)
                .setFeatureNames(featureNames)
                .build()
                .validate();
        });
        assertThat(ex.getMessage(), equalTo("[aggregate_output] expects value array of size [7] but number of models is [5]"));
    }

    public void testEnsembleWithInvalidModel() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder()
                .setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(
                    // Tree with loop
                    Tree.builder()
                        .setNodes(TreeNode.builder(0)
                                .setLeftChild(1)
                                .setSplitFeature(1)
                                .setThreshold(randomDouble()),
                            TreeNode.builder(0)
                                .setLeftChild(0)
                                .setSplitFeature(1)
                                .setThreshold(randomDouble()))
                        .setFeatureNames(featureNames)
                        .build()))
                .build()
                .validate();
        });
    }

    public void testEnsembleWithTargetTypeAndLabelsMismatch() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder()
                .setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(
                    Tree.builder()
                        .setNodes(TreeNode.builder(0)
                            .setLeftChild(1)
                            .setSplitFeature(1)
                            .setThreshold(randomDouble()))
                        .setFeatureNames(featureNames)
                        .build()))
                .setClassificationLabels(Arrays.asList("label1", "label2"))
                .build()
                .validate();
        });
        assertThat(ex.getMessage(), equalTo(msg));
        ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder()
                .setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(
                    Tree.builder()
                        .setNodes(TreeNode.builder(0)
                            .setLeftChild(1)
                            .setSplitFeature(1)
                            .setThreshold(randomDouble()))
                        .setFeatureNames(featureNames)
                        .build()))
                .setTargetType(TargetType.CLASSIFICATION)
                .build()
                .validate();
        });
        assertThat(ex.getMessage(), equalTo(msg));
    }

    public void testClassificationProbability() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2)
                .setThreshold(0.8)
                .setSplitFeature(1)
                .setLeftChild(3)
                .setRightChild(4))
            .addNode(TreeNode.builder(3).setLeafValue(0.0))
            .addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(0.0))
            .addNode(TreeNode.builder(2).setLeafValue(1.0))
            .build();
        Tree tree3 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(1)
                .setThreshold(1.0))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2).setLeafValue(0.0))
            .build();
        Ensemble ensemble = Ensemble.builder()
            .setTargetType(TargetType.CLASSIFICATION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
            .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
            .build();

        List<Double> featureVector = Arrays.asList(0.4, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        List<Double> expected = Arrays.asList(0.231475216, 0.768524783);
        double eps = 0.000001;
        List<Double> probabilities = ensemble.classificationProbability(featureMap);
        for (int i = 0; i < expected.size(); i++) {
            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
        }

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        expected = Arrays.asList(0.3100255188, 0.689974481);
        probabilities = ensemble.classificationProbability(featureMap);
        for (int i = 0; i < expected.size(); i++) {
            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
        }

        featureVector = Arrays.asList(0.0, 1.0);
        featureMap = zipObjMap(featureNames, featureVector);
        expected = Arrays.asList(0.231475216, 0.768524783);
        probabilities = ensemble.classificationProbability(featureMap);
        for (int i = 0; i < expected.size(); i++) {
            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
        }
    }

    public void testClassificationInference() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2)
                .setThreshold(0.8)
                .setSplitFeature(1)
                .setLeftChild(3)
                .setRightChild(4))
            .addNode(TreeNode.builder(3).setLeafValue(0.0))
            .addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(0.0))
            .addNode(TreeNode.builder(2).setLeafValue(1.0))
            .build();
        Tree tree3 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(1)
                .setThreshold(1.0))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2).setLeafValue(0.0))
            .build();
        Ensemble ensemble = Ensemble.builder()
            .setTargetType(TargetType.CLASSIFICATION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
            .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
            .build();

        List<Double> featureVector = Arrays.asList(0.4, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);

        featureVector = Arrays.asList(0.0, 1.0);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
    }

    public void testRegressionInference() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(0.3))
            .addNode(TreeNode.builder(2)
                .setThreshold(0.8)
                .setSplitFeature(1)
                .setLeftChild(3)
                .setRightChild(4))
            .addNode(TreeNode.builder(3).setLeafValue(0.1))
            .addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(1.5))
            .addNode(TreeNode.builder(2).setLeafValue(0.9))
            .build();
        Ensemble ensemble = Ensemble.builder()
            .setTargetType(TargetType.REGRESSION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2))
            .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5)))
            .build();

        List<Double> featureVector = Arrays.asList(0.4, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(0.9, ensemble.infer(featureMap), 0.00001);

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(0.5, ensemble.infer(featureMap), 0.00001);

        // Test with NO aggregator supplied, verifies default behavior of non-weighted sum
        ensemble = Ensemble.builder()
            .setTargetType(TargetType.REGRESSION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2))
            .build();

        featureVector = Arrays.asList(0.4, 0.0);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.8, ensemble.infer(featureMap), 0.00001);

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
    }

    private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
        return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
    }
}
@@ -0,0 +1,51 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.test.AbstractSerializingTestCase;
import org.junit.Before;

import java.util.ArrayList;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;

public abstract class WeightedAggregatorTests<T extends OutputAggregator> extends AbstractSerializingTestCase<T> {

    protected boolean lenient;

    @Before
    public void chooseStrictOrLenient() {
        lenient = randomBoolean();
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    public void testWithNullValues() {
        OutputAggregator outputAggregator = createTestInstance();
        NullPointerException ex = expectThrows(NullPointerException.class, () -> outputAggregator.processValues(null));
        assertThat(ex.getMessage(), equalTo("values must not be null"));
    }

    public void testWithValuesOfWrongLength() {
        int numberOfValues = randomIntBetween(5, 10);
        List<Double> values = new ArrayList<>(numberOfValues);
        for (int i = 0; i < numberOfValues; i++) {
            values.add(randomDouble());
        }

        OutputAggregator outputAggregatorWithTooFewWeights = createTestInstance(randomIntBetween(1, numberOfValues - 1));
        expectThrows(IllegalArgumentException.class, () -> outputAggregatorWithTooFewWeights.processValues(values));

        OutputAggregator outputAggregatorWithTooManyWeights = createTestInstance(randomIntBetween(numberOfValues + 1, numberOfValues + 10));
        expectThrows(IllegalArgumentException.class, () -> outputAggregatorWithTooManyWeights.processValues(values));
    }

    abstract T createTestInstance(int numberOfWeights);
}
@@ -0,0 +1,58 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.hamcrest.Matchers.equalTo;

public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {

    @Override
    WeightedMode createTestInstance(int numberOfWeights) {
        List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
        return new WeightedMode(weights);
    }

    @Override
    protected WeightedMode doParseInstance(XContentParser parser) throws IOException {
        return lenient ? WeightedMode.fromXContentLenient(parser) : WeightedMode.fromXContentStrict(parser);
    }

    @Override
    protected WeightedMode createTestInstance() {
        return randomBoolean() ? new WeightedMode() : createTestInstance(randomIntBetween(1, 100));
    }

    @Override
    protected Writeable.Reader<WeightedMode> instanceReader() {
        return WeightedMode::new;
    }

    public void testAggregate() {
        List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);

        WeightedMode weightedMode = new WeightedMode(ones);
        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));

        List<Double> variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0);

        weightedMode = new WeightedMode(variedWeights);
        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0));

        weightedMode = new WeightedMode();
        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
    }
}
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {
|
||||
|
||||
@Override
|
||||
WeightedSum createTestInstance(int numberOfWeights) {
|
||||
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
|
||||
return new WeightedSum(weights);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected WeightedSum doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? WeightedSum.fromXContentLenient(parser) : WeightedSum.fromXContentStrict(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected WeightedSum createTestInstance() {
|
||||
return randomBoolean() ? new WeightedSum() : createTestInstance(randomIntBetween(1, 100));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<WeightedSum> instanceReader() {
|
||||
return WeightedSum::new;
|
||||
}
|
||||
|
||||
public void testAggregate() {
|
||||
List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
|
||||
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
|
||||
|
||||
WeightedSum weightedSum = new WeightedSum(ones);
|
||||
assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));
|
||||
|
||||
List<Double> variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0);
|
||||
|
||||
weightedSum = new WeightedSum(variedWeights);
|
||||
assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(28.0));
|
||||
|
||||
weightedSum = new WeightedSum();
|
||||
assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));
|
||||
}
|
||||
}
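
The WeightedSum expectations are a plain dot product: with unit weights, 1 + 2 + 2 + 3 + 5 = 13.0; with weights [1.0, -1.0, .5, 1.0, 5.0], 1(1) + 2(-1) + 2(0.5) + 3(1) + 5(5) = 1 - 2 + 1 + 3 + 25 = 28.0. A one-method sketch under the same assumption as above (processValues and aggregate collapsed into one hypothetical helper):

import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;

public final class WeightedSumSketch {

    // Multiply each value by its weight and sum the products.
    static double weightedSum(List<Double> values, List<Double> weights) {
        return IntStream.range(0, values.size())
            .mapToDouble(i -> values.get(i) * weights.get(i))
            .sum();
    }

    public static void main(String[] args) {
        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
        System.out.println(weightedSum(values, Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0)));   // 13.0
        System.out.println(weightedSum(values, Arrays.asList(1.0, -1.0, .5, 1.0, 5.0)));   // 28.0
    }
}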

@@ -5,9 +5,12 @@
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;

+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.junit.Before;

import java.io.IOException;

@@ -47,23 +50,23 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
        return field -> field.startsWith("feature_names");
    }

    @Override
    protected Tree createTestInstance() {
        return createRandom();
    }

    public static Tree createRandom() {
-        return buildRandomTree(randomIntBetween(2, 15), 6);
-    }
-
-    public static Tree buildRandomTree(int numFeatures, int depth) {
-        Tree.Builder builder = Tree.builder();
-        List<String> featureNames = new ArrayList<>(numFeatures);
-        for(int i = 0; i < numFeatures; i++) {
+        int numberOfFeatures = randomIntBetween(1, 10);
+        List<String> featureNames = new ArrayList<>();
+        for (int i = 0; i < numberOfFeatures; i++) {
            featureNames.add(randomAlphaOfLength(10));
        }
+        return buildRandomTree(featureNames, 6);
+    }
+
+    public static Tree buildRandomTree(List<String> featureNames, int depth) {
+        Tree.Builder builder = Tree.builder();
+        int numFeatures = featureNames.size() - 1;
        builder.setFeatureNames(featureNames);

        TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());

@@ -84,8 +87,14 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
            }
            childNodes = nextNodes;
        }
+        List<String> categoryLabels = null;
+        if (randomBoolean()) {
+            categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
+        }

-        return builder.build();
+        return builder.setTargetType(randomFrom(TargetType.values()))
+            .setClassificationLabels(categoryLabels)
+            .build();
    }

    @Override

@@ -96,7 +105,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
    public void testInfer() {
        // Build a tree with 2 nodes and 3 leaves using 2 features
        // The leaves have unique values 0.1, 0.2, 0.3
-        Tree.Builder builder = Tree.builder();
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
        builder.addLeaf(rootNode.getRightChild(), 0.3);
        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);

@@ -124,37 +133,76 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
        assertEquals(0.2, tree.infer(featureMap), 0.00001);
    }

+    public void testTreeClassificationProbability() {
+        // Build a tree with 2 nodes and 3 leaves using 2 features
+        // The leaves have unique values 0.1, 0.2, 0.3
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.CLASSIFICATION);
+        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
+        builder.addLeaf(rootNode.getRightChild(), 1.0);
+        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
+        builder.addLeaf(leftChildNode.getLeftChild(), 1.0);
+        builder.addLeaf(leftChildNode.getRightChild(), 0.0);
+
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree tree = builder.setFeatureNames(featureNames).build();
+
+        // This feature vector should hit the right child of the root node
+        List<Double> featureVector = Arrays.asList(0.6, 0.0);
+        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
+        assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));
+
+        // This should hit the left child of the left child of the root node
+        // i.e. it takes the path left, left
+        featureVector = Arrays.asList(0.3, 0.7);
+        featureMap = zipObjMap(featureNames, featureVector);
+        assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));
+
+        // This should hit the right child of the left child of the root node
+        // i.e. it takes the path left, right
+        featureVector = Arrays.asList(0.3, 0.9);
+        featureMap = zipObjMap(featureNames, featureVector);
+        assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap));
+    }
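
The routing in testTreeClassificationProbability follows the usual binary-tree convention: at each junction, compare the split feature's value against the threshold and descend left when it is smaller, so [0.6, 0.0] goes right at the root (0.6 >= 0.5), [0.3, 0.7] goes left then left (0.3 < 0.5, then 0.7 < 0.8), and [0.3, 0.9] goes left then right. A stripped-down traversal sketch; the node layout and names here are hypothetical, and the production TreeNode's missing-value handling (defaultLeft) is omitted:

final class TreeSketch {

    // Leaf when left == null; otherwise an internal junction.
    static final class Node {
        int splitFeature;   // index into the feature vector
        double threshold;   // go left when feature value < threshold
        Node left, right;
        double leafValue;   // only meaningful at a leaf
    }

    static double infer(Node node, double[] features) {
        while (node.left != null) {
            node = features[node.splitFeature] < node.threshold ? node.left : node.right;
        }
        return node.leafValue;
    }
}

In the classification case, the assertions suggest the leaf value is read as a class index: a leaf of 1.0 yields probabilities [0.0, 1.0] and a leaf of 0.0 yields [1.0, 0.0].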

    public void testTreeWithNullRoot() {
-        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
-            () -> Tree.builder().setNodes(Collections.singletonList(null))
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
+            () -> Tree.builder()
+                .setNodes(Collections.singletonList(null))
+                .setFeatureNames(Arrays.asList("foo", "bar"))
                .build());
-        assertThat(ex.getMessage(), equalTo("[tree] must have non-null root node."));
+        assertThat(ex.getMessage(), equalTo("[tree] cannot contain null nodes"));
    }

    public void testTreeWithInvalidNode() {
-        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
-            () -> Tree.builder().setNodes(TreeNode.builder(0)
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
+            () -> Tree.builder()
+                .setNodes(TreeNode.builder(0)
                    .setLeftChild(1)
                    .setSplitFeature(1)
                    .setThreshold(randomDouble()))
-                .build());
-        assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]"));
+                .setFeatureNames(Arrays.asList("foo", "bar"))
+                .build().validate());
+        assertThat(ex.getMessage(), equalTo("[tree] contains missing nodes [1]"));
    }

    public void testTreeWithNullNode() {
-        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
-            () -> Tree.builder().setNodes(TreeNode.builder(0)
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
+            () -> Tree.builder()
+                .setNodes(TreeNode.builder(0)
                    .setLeftChild(1)
                    .setSplitFeature(1)
                    .setThreshold(randomDouble()),
                    null)
-                .build());
-        assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]"));
+                .setFeatureNames(Arrays.asList("foo", "bar"))
+                .build()
+                .validate());
+        assertThat(ex.getMessage(), equalTo("[tree] cannot contain null nodes"));
    }

    public void testTreeWithCycle() {
-        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
-            () -> Tree.builder().setNodes(TreeNode.builder(0)
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
+            () -> Tree.builder()
+                .setNodes(TreeNode.builder(0)
                    .setLeftChild(1)
                    .setSplitFeature(1)
                    .setThreshold(randomDouble()),

@@ -162,10 +210,41 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
                    .setLeftChild(0)
                    .setSplitFeature(1)
                    .setThreshold(randomDouble()))
-                .build());
+                .setFeatureNames(Arrays.asList("foo", "bar"))
+                .build()
+                .validate());
        assertThat(ex.getMessage(), equalTo("[tree] contains cycle at node 0"));
    }

+    public void testTreeWithTargetTypeAndLabelsMismatch() {
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
+            Tree.builder()
+                .setRoot(TreeNode.builder(0)
+                    .setLeftChild(1)
+                    .setSplitFeature(1)
+                    .setThreshold(randomDouble()))
+                .setFeatureNames(featureNames)
+                .setClassificationLabels(Arrays.asList("label1", "label2"))
+                .build()
+                .validate();
+        });
+        assertThat(ex.getMessage(), equalTo(msg));
+        ex = expectThrows(ElasticsearchException.class, () -> {
+            Tree.builder()
+                .setRoot(TreeNode.builder(0)
+                    .setLeftChild(1)
+                    .setSplitFeature(1)
+                    .setThreshold(randomDouble()))
+                .setFeatureNames(featureNames)
+                .setTargetType(TargetType.CLASSIFICATION)
+                .build()
+                .validate();
+        });
+        assertThat(ex.getMessage(), equalTo(msg));
+    }

    private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
        return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
    }
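
testTreeWithCycle wires node 1's left child back to node 0, and validate() is expected to refuse it. One way to catch that, sketched with a plain visited-set walk from the root; this is an assumption about the approach, not necessarily how Tree.validate() is implemented:

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Set;

final class TreeCycleCheck {

    // nodes[i] = {leftChildIndex, rightChildIndex}, with -1 meaning "no child".
    // Reaching any index twice while walking from the root means the child
    // pointers do not form a tree, so fail at the revisited node.
    static void checkForCycle(int[][] nodes) {
        Set<Integer> visited = new HashSet<>();
        Deque<Integer> toVisit = new ArrayDeque<>();
        toVisit.push(0); // node 0 is the root
        while (!toVisit.isEmpty()) {
            int current = toVisit.pop();
            if (!visited.add(current)) {
                throw new IllegalStateException("[tree] contains cycle at node " + current);
            }
            for (int child : nodes[current]) {
                if (child >= 0) {
                    toVisit.push(child);
                }
            }
        }
    }
}

With nodes {{1, -1}, {0, -1}} the walk visits 0, then 1, then reaches 0 again and fails, matching the "[tree] contains cycle at node 0" message asserted above.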

@@ -0,0 +1,33 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.utils;

import org.elasticsearch.test.ESTestCase;

import java.util.Arrays;
import java.util.List;

import static org.hamcrest.Matchers.closeTo;

public class StatisticsTests extends ESTestCase {

    public void testSoftMax() {
        List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, 1.0, -0.5, null, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0);
        List<Double> softMax = Statistics.softMax(values);

        List<Double> expected = Arrays.asList(0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042);

        for(int i = 0; i < expected.size(); i++) {
            assertThat(softMax.get(i), closeTo(expected.get(i), 0.000001));
        }
    }

    public void testSoftMaxWithNoValidValues() {
        List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, null, Double.NaN, Double.POSITIVE_INFINITY);
        expectThrows(IllegalArgumentException.class, () -> Statistics.softMax(values));
    }
}
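
For reference, the expected values in testSoftMax come from softmax(x_i) = exp(x_i) / sum_j exp(x_j) computed over the valid entries only: exp(1) + exp(-0.5) + exp(1) + exp(5) is approximately 2.71828 + 0.60653 + 2.71828 + 148.41316 = 154.45625, so exp(1)/154.45625 gives 0.017599, exp(-0.5)/154.45625 gives 0.003927, and exp(5)/154.45625 gives 0.960875, while null, NaN, and positive/negative infinity map to 0.0. A hedged re-implementation of exactly the behavior the two tests assert; the real Statistics.softMax may differ internally, for example by subtracting the max exponent for numerical stability:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class SoftMaxSketch {

    // Invalid entries contribute 0.0 and are excluded from the normalizer;
    // if nothing is valid there is nothing to normalize, so throw.
    static List<Double> softMax(List<Double> values) {
        double sum = 0.0;
        for (Double v : values) {
            if (isValid(v)) {
                sum += Math.exp(v);
            }
        }
        if (sum == 0.0) {
            throw new IllegalArgumentException("no valid values present");
        }
        List<Double> result = new ArrayList<>(values.size());
        for (Double v : values) {
            result.add(isValid(v) ? Math.exp(v) / sum : 0.0);
        }
        return result;
    }

    private static boolean isValid(Double v) {
        return v != null && !v.isNaN() && !v.isInfinite();
    }

    public static void main(String[] args) {
        System.out.println(softMax(Arrays.asList(Double.NEGATIVE_INFINITY, 1.0, -0.5, null,
            Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0)));
    }
}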