[ML][Inference] adding ensemble model objects (#47241) (#47438)

* [ML][Inference] adding ensemble model objects

* addressing PR comments

* Update TreeTests.java

* addressing PR comments

* fixing test
Benjamin Trent 2019-10-02 09:49:46 -04:00, committed by GitHub
parent b9541eb3af
commit 2228a7dd8d
31 changed files with 2421 additions and 119 deletions


@@ -19,6 +19,10 @@
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
@@ -47,6 +51,15 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
// Model
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Ensemble.NAME), Ensemble::fromXContent));
// Aggregating output
namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
new ParseField(WeightedMode.NAME),
WeightedMode::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
new ParseField(WeightedSum.NAME),
WeightedSum::fromXContent));
return namedXContent;
}
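A minimal sketch (not part of the commit) of how a client might consume these registrations: build a NamedXContentRegistry from the provider and hand it to a JSON parser so that Ensemble.fromXContent can resolve the nested trained_models and aggregate_output entries. The JSON literal is invented; only its field names come from the ParseFields declared on Ensemble.

NamedXContentRegistry registry = new NamedXContentRegistry(
    new MlInferenceNamedXContentProvider().getNamedXContentParsers());
String json = "{\"feature_names\":[\"f1\",\"f2\"],"
    + "\"trained_models\":[],"
    + "\"aggregate_output\":{\"weighted_sum\":{\"weights\":[0.5,0.5]}},"
    + "\"target_type\":\"regression\"}";
try (XContentParser parser = XContentType.JSON.xContent()
        .createParser(registry, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json)) {
    Ensemble ensemble = Ensemble.fromXContent(parser);
}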


@@ -0,0 +1,35 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel;
import java.util.Locale;
public enum TargetType {
REGRESSION, CLASSIFICATION;
public static TargetType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
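For example (illustration only, not part of the commit), parsing is case-insensitive and whitespace-tolerant, while serialization is always lower case:

TargetType type = TargetType.fromString(" Classification ");
assert type == TargetType.CLASSIFICATION;
assert type.toString().equals("classification");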


@@ -0,0 +1,188 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
public class Ensemble implements TrainedModel {
public static final String NAME = "ensemble";
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
public static final ParseField TRAINED_MODELS = new ParseField("trained_models");
public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
public static final ParseField TARGET_TYPE = new ParseField("target_type");
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
NAME,
true,
Ensemble.Builder::new);
static {
PARSER.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES);
PARSER.declareNamedObjects(Ensemble.Builder::setTrainedModels,
(p, c, n) ->
p.namedObject(TrainedModel.class, n, null),
(ensembleBuilder) -> { /* Noop does not matter client side */ },
TRAINED_MODELS);
PARSER.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
(p, c, n) -> p.namedObject(OutputAggregator.class, n, null),
(ensembleBuilder) -> { /* Noop does not matter client side */ },
AGGREGATE_OUTPUT);
PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
}
public static Ensemble fromXContent(XContentParser parser) {
return PARSER.apply(parser, null).build();
}
private final List<String> featureNames;
private final List<TrainedModel> models;
private final OutputAggregator outputAggregator;
private final TargetType targetType;
private final List<String> classificationLabels;
Ensemble(List<String> featureNames,
List<TrainedModel> models,
@Nullable OutputAggregator outputAggregator,
TargetType targetType,
@Nullable List<String> classificationLabels) {
this.featureNames = featureNames;
this.models = models;
this.outputAggregator = outputAggregator;
this.targetType = targetType;
this.classificationLabels = classificationLabels;
}
@Override
public List<String> getFeatureNames() {
return featureNames;
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
if (featureNames != null) {
builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
}
if (models != null) {
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models);
}
if (outputAggregator != null) {
NamedXContentObjectHelper.writeNamedObjects(builder,
params,
false,
AGGREGATE_OUTPUT.getPreferredName(),
Collections.singletonList(outputAggregator));
}
if (targetType != null) {
builder.field(TARGET_TYPE.getPreferredName(), targetType);
}
if (classificationLabels != null) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Ensemble that = (Ensemble) o;
return Objects.equals(featureNames, that.featureNames)
&& Objects.equals(models, that.models)
&& Objects.equals(targetType, that.targetType)
&& Objects.equals(classificationLabels, that.classificationLabels)
&& Objects.equals(outputAggregator, that.outputAggregator);
}
@Override
public int hashCode() {
return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType);
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private List<String> featureNames;
private List<TrainedModel> trainedModels;
private OutputAggregator outputAggregator;
private TargetType targetType;
private List<String> classificationLabels;
public Builder setFeatureNames(List<String> featureNames) {
this.featureNames = featureNames;
return this;
}
public Builder setTrainedModels(List<TrainedModel> trainedModels) {
this.trainedModels = trainedModels;
return this;
}
public Builder setOutputAggregator(OutputAggregator outputAggregator) {
this.outputAggregator = outputAggregator;
return this;
}
public Builder setTargetType(TargetType targetType) {
this.targetType = targetType;
return this;
}
public Builder setClassificationLabels(List<String> classificationLabels) {
this.classificationLabels = classificationLabels;
return this;
}
private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
this.setOutputAggregator(outputAggregators.get(0));
}
private void setTargetType(String targetType) {
this.targetType = TargetType.fromString(targetType);
}
public Ensemble build() {
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
}
}
}
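A short usage sketch (illustrative values; Strings.toString from org.elasticsearch.common.Strings is assumed for rendering): the client-side Builder performs no validation, so even a skeletal ensemble can be built and serialized.

Ensemble ensemble = Ensemble.builder()
    .setFeatureNames(Arrays.asList("petal_length", "petal_width")) // made-up feature names
    .setTrainedModels(Collections.emptyList()) // sub-models such as Tree instances would go here
    .setOutputAggregator(new WeightedSum(Arrays.asList(0.7, 0.3)))
    .setTargetType(TargetType.REGRESSION)
    .build();
String json = Strings.toString(ensemble);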


@@ -0,0 +1,28 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.client.ml.inference.NamedXContentObject;
public interface OutputAggregator extends NamedXContentObject {
/**
* @return The name of the output aggregator
*/
String getName();
}


@@ -0,0 +1,84 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
public class WeightedMode implements OutputAggregator {
public static final String NAME = "weighted_mode";
public static final ParseField WEIGHTS = new ParseField("weights");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<WeightedMode, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new WeightedMode((List<Double>)a[0]));
static {
PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
}
public static WeightedMode fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final List<Double> weights;
public WeightedMode(List<Double> weights) {
this.weights = weights;
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
if (weights != null) {
builder.field(WEIGHTS.getPreferredName(), weights);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedMode that = (WeightedMode) o;
return Objects.equals(weights, that.weights);
}
@Override
public int hashCode() {
return Objects.hash(weights);
}
}


@@ -0,0 +1,84 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
public class WeightedSum implements OutputAggregator {
public static final String NAME = "weighted_sum";
public static final ParseField WEIGHTS = new ParseField("weights");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<WeightedSum, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new WeightedSum((List<Double>)a[0]));
static {
PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
}
public static WeightedSum fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final List<Double> weights;
public WeightedSum(List<Double> weights) {
this.weights = weights;
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
if (weights != null) {
builder.field(WEIGHTS.getPreferredName(), weights);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedSum that = (WeightedSum) o;
return Objects.equals(weights, that.weights);
}
@Override
public int hashCode() {
return Objects.hash(weights);
}
}


@@ -18,7 +18,9 @@
*/
package org.elasticsearch.client.ml.inference.trainedmodel.tree;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ObjectParser;
@@ -28,7 +30,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
- import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@@ -39,12 +40,16 @@ public class Tree implements TrainedModel {
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
public static final ParseField TARGET_TYPE = new ParseField("target_type");
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, true, Builder::new);
static {
PARSER.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES);
PARSER.declareObjectArray(Builder::setNodes, (p, c) -> TreeNode.fromXContent(p), TREE_STRUCTURE);
PARSER.declareString(Builder::setTargetType, TARGET_TYPE);
PARSER.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
}
public static Tree fromXContent(XContentParser parser) {
@@ -53,10 +58,14 @@ public class Tree implements TrainedModel {
private final List<String> featureNames;
private final List<TreeNode> nodes;
private final TargetType targetType;
private final List<String> classificationLabels;
- Tree(List<String> featureNames, List<TreeNode> nodes) {
- this.featureNames = Collections.unmodifiableList(Objects.requireNonNull(featureNames));
- this.nodes = Collections.unmodifiableList(Objects.requireNonNull(nodes));
+ Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
+ this.featureNames = featureNames;
+ this.nodes = nodes;
+ this.targetType = targetType;
+ this.classificationLabels = classificationLabels;
}
@Override
@@ -73,11 +82,30 @@ public class Tree implements TrainedModel {
return nodes;
}
@Nullable
public List<String> getClassificationLabels() {
return classificationLabels;
}
public TargetType getTargetType() {
return targetType;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
- builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
- builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
+ if (featureNames != null) {
+ builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
+ }
+ if (nodes != null) {
+ builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
+ }
+ if (classificationLabels != null) {
+ builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
+ }
+ if (targetType != null) {
+ builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
+ }
builder.endObject();
return builder;
}
@@ -93,12 +121,14 @@
if (o == null || getClass() != o.getClass()) return false;
Tree that = (Tree) o;
return Objects.equals(featureNames, that.featureNames)
&& Objects.equals(classificationLabels, that.classificationLabels)
&& Objects.equals(targetType, that.targetType)
&& Objects.equals(nodes, that.nodes);
}
@Override
public int hashCode() {
- return Objects.hash(featureNames, nodes);
+ return Objects.hash(featureNames, nodes, targetType, classificationLabels);
}
public static Builder builder() {
@@ -109,6 +139,8 @@ public class Tree implements TrainedModel {
private List<String> featureNames;
private ArrayList<TreeNode.Builder> nodes;
private int numNodes;
private TargetType targetType;
private List<String> classificationLabels;
public Builder() {
nodes = new ArrayList<>();
@@ -137,6 +169,20 @@
return setNodes(Arrays.asList(nodes));
}
public Builder setTargetType(TargetType targetType) {
this.targetType = targetType;
return this;
}
public Builder setClassificationLabels(List<String> classificationLabels) {
this.classificationLabels = classificationLabels;
return this;
}
private void setTargetType(String targetType) {
this.targetType = TargetType.fromString(targetType);
}
/**
* Add a decision node. Space for the child nodes is allocated
* @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index
@@ -185,7 +231,9 @@
public Tree build() {
return new Tree(featureNames,
- nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()));
+ nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()),
+ targetType,
+ classificationLabels);
}
}
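A builder sketch for the extended Tree (assumptions: addLeaf(int, double) is taken from the pre-existing Tree.Builder API and is not shown in this diff; the labels are invented):

Tree tree = Tree.builder()
    .setFeatureNames(Collections.singletonList("f0"))
    .addLeaf(0, 1.0) // single-leaf tree that always predicts 1.0
    .setTargetType(TargetType.CLASSIFICATION)
    .setClassificationLabels(Arrays.asList("no", "yes"))
    .build();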


@@ -65,6 +65,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
@@ -681,7 +684,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
- assertEquals(41, namedXContents.size());
+ assertEquals(44, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -691,7 +694,7 @@
categories.put(namedXContent.categoryClass, counter + 1);
}
}
assertEquals("Had: " + categories, 11, categories.size());
assertEquals("Had: " + categories, 12, categories.size());
assertEquals(Integer.valueOf(3), categories.get(Aggregation.class));
assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
@@ -740,8 +743,11 @@
RSquaredMetric.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
- assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
- assertThat(names, hasItems(Tree.NAME));
+ assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
+ assertThat(names, hasItems(Tree.NAME, Ensemble.NAME));
+ assertEquals(Integer.valueOf(2),
+ categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
+ assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME));
}
public void testApiNamingConventions() throws Exception {


@@ -0,0 +1,97 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> !field.isEmpty();
}
@Override
protected Ensemble doParseInstance(XContentParser parser) throws IOException {
return Ensemble.fromXContent(parser);
}
public static Ensemble createRandom() {
int numberOfFeatures = randomIntBetween(1, 10);
List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10))
.limit(numberOfFeatures)
.collect(Collectors.toList());
int numberOfModels = randomIntBetween(1, 10);
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
.limit(numberOfModels)
.collect(Collectors.toList());
OutputAggregator outputAggregator = null;
if (randomBoolean()) {
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
}
List<String> categoryLabels = null;
if (randomBoolean()) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
return new Ensemble(featureNames,
models,
outputAggregator,
randomFrom(TargetType.values()),
categoryLabels);
}
@Override
protected Ensemble createTestInstance() {
return createRandom();
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
}


@@ -0,0 +1,51 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class WeightedModeTests extends AbstractXContentTestCase<WeightedMode> {
WeightedMode createTestInstance(int numberOfWeights) {
return new WeightedMode(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
}
@Override
protected WeightedMode doParseInstance(XContentParser parser) throws IOException {
return WeightedMode.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected WeightedMode createTestInstance() {
return randomBoolean() ? new WeightedMode(null) : createTestInstance(randomIntBetween(1, 100));
}
}


@@ -0,0 +1,51 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class WeightedSumTests extends AbstractXContentTestCase<WeightedSum> {
WeightedSum createTestInstance(int numberOfWeights) {
return new WeightedSum(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
}
@Override
protected WeightedSum doParseInstance(XContentParser parser) throws IOException {
return WeightedSum.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected WeightedSum createTestInstance() {
return randomBoolean() ? new WeightedSum(null) : createTestInstance(randomIntBetween(1, 100));
}
}


@@ -18,6 +18,7 @@
*/
package org.elasticsearch.client.ml.inference.trainedmodel.tree;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
@@ -51,16 +52,17 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
}
public static Tree createRandom() {
- return buildRandomTree(randomIntBetween(2, 15), 6);
- }
- public static Tree buildRandomTree(int numFeatures, int depth) {
- Tree.Builder builder = Tree.builder();
- List<String> featureNames = new ArrayList<>(numFeatures);
- for(int i = 0; i < numFeatures; i++) {
+ int numberOfFeatures = randomIntBetween(1, 10);
+ List<String> featureNames = new ArrayList<>();
+ for (int i = 0; i < numberOfFeatures; i++) {
featureNames.add(randomAlphaOfLength(10));
}
+ return buildRandomTree(featureNames, 6);
+ }
+ public static Tree buildRandomTree(List<String> featureNames, int depth) {
+ int numFeatures = featureNames.size();
+ Tree.Builder builder = Tree.builder();
builder.setFeatureNames(featureNames);
TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
@@ -81,8 +83,13 @@
}
childNodes = nextNodes;
}
- return builder.build();
+ List<String> categoryLabels = null;
+ if (randomBoolean()) {
+ categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
+ }
+ return builder.setClassificationLabels(categoryLabels)
+ .setTargetType(randomFrom(TargetType.values()))
+ .build();
}
}


@@ -11,6 +11,12 @@ import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
@@ -46,9 +52,27 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
// Model Lenient
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Ensemble.NAME, Ensemble::fromXContentLenient));
// Output Aggregator Lenient
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
WeightedMode.NAME,
WeightedMode::fromXContentLenient));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
WeightedSum.NAME,
WeightedSum::fromXContentLenient));
// Model Strict
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Ensemble.NAME, Ensemble::fromXContentStrict));
// Output Aggregator Strict
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
WeightedMode.NAME,
WeightedMode::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
WeightedSum.NAME,
WeightedSum::fromXContentStrict));
return namedXContent;
}
@@ -66,6 +90,15 @@
// Model
namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new));
// Output Aggregator
namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
WeightedSum.NAME.getPreferredName(),
WeightedSum::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
WeightedMode.NAME.getPreferredName(),
WeightedMode::new));
return namedWriteables;
}


@@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import java.io.IOException;
import java.util.Locale;
public enum TargetType implements Writeable {
REGRESSION, CLASSIFICATION;
public static TargetType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
public static TargetType fromStream(StreamInput in) throws IOException {
return in.readEnum(TargetType.class);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(this);
}
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
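A wire round-trip sketch (test-style usage, not part of the commit) of the Writeable contract, assuming BytesStreamOutput from org.elasticsearch.common.io.stream:

BytesStreamOutput out = new BytesStreamOutput();
TargetType.CLASSIFICATION.writeTo(out);
TargetType roundTripped = TargetType.fromStream(out.bytes().streamInput());
assert roundTripped == TargetType.CLASSIFICATION;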


@@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
@@ -28,17 +29,47 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable {
double infer(Map<String, Object> fields);
/**
- * @return {@code true} if the model is classification, {@code false} otherwise.
+ * @param fields similar to {@link TrainedModel#infer(Map)}, but fields are already in order and doubles
+ * @return The predicted value.
*/
- boolean isClassification();
+ double infer(List<Double> fields);
/**
* @return {@link TargetType} for the model.
*/
TargetType targetType();
/**
* This gathers the probabilities for each potential classification value.
*
* The probabilities are indexed by classification ordinal label encoding.
* The length of this list is equal to the number of classification labels.
*
* This should only return values if the implementing model is inferring classification values and not regression
* @param fields The fields and their values to infer against
* @return The probabilities of each classification value
*/
- List<Double> inferProbabilities(Map<String, Object> fields);
+ List<Double> classificationProbability(Map<String, Object> fields);
/**
* @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles
* @return The probabilities of each classification value
*/
List<Double> classificationProbability(List<Double> fields);
/**
* The ordinal encoded list of the classification labels.
* @return Ordinal encoded list of classification labels.
*/
@Nullable
List<String> classificationLabels();
/**
* Runs validations against the model.
*
* Example: {@link org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree} should check if there are any loops
*
* @throws org.elasticsearch.ElasticsearchException if validations fail
*/
void validate();
}
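Put together, a caller would use the revised interface roughly like this (sketch; model is any TrainedModel, for example the Ensemble below):

double prediction = model.infer(fields); // fields is a Map<String, Object>
if (model.targetType() == TargetType.CLASSIFICATION) {
    List<Double> probabilities = model.classificationProbability(fields);
    List<String> labels = model.classificationLabels(); // may be null
    // probabilities.get(i) corresponds to labels.get(i) via the ordinal encoding
}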


@@ -0,0 +1,311 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
// TODO should we have regression/classification sub-classes that accept the builder?
public static final ParseField NAME = new ParseField("ensemble");
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
public static final ParseField TRAINED_MODELS = new ParseField("trained_models");
public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
public static final ParseField TARGET_TYPE = new ParseField("target_type");
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
private static final ObjectParser<Ensemble.Builder, Void> LENIENT_PARSER = createParser(true);
private static final ObjectParser<Ensemble.Builder, Void> STRICT_PARSER = createParser(false);
private static ObjectParser<Ensemble.Builder, Void> createParser(boolean lenient) {
ObjectParser<Ensemble.Builder, Void> parser = new ObjectParser<>(
NAME.getPreferredName(),
lenient,
Ensemble.Builder::builderForParser);
parser.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES);
parser.declareNamedObjects(Ensemble.Builder::setTrainedModels,
(p, c, n) ->
lenient ? p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
(ensembleBuilder) -> ensembleBuilder.setModelsAreOrdered(true),
TRAINED_MODELS);
parser.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
(p, c, n) ->
lenient ? p.namedObject(LenientlyParsedOutputAggregator.class, n, null) :
p.namedObject(StrictlyParsedOutputAggregator.class, n, null),
(ensembleBuilder) -> {/* Noop: the input could be an array or an object, there just has to be exactly one aggregator */},
AGGREGATE_OUTPUT);
parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
return parser;
}
public static Ensemble fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null).build();
}
public static Ensemble fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null).build();
}
private final List<String> featureNames;
private final List<TrainedModel> models;
private final OutputAggregator outputAggregator;
private final TargetType targetType;
private final List<String> classificationLabels;
Ensemble(List<String> featureNames,
List<TrainedModel> models,
OutputAggregator outputAggregator,
TargetType targetType,
@Nullable List<String> classificationLabels) {
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS));
this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
}
public Ensemble(StreamInput in) throws IOException {
this.featureNames = Collections.unmodifiableList(in.readStringList());
this.models = Collections.unmodifiableList(in.readNamedWriteableList(TrainedModel.class));
this.outputAggregator = in.readNamedWriteable(OutputAggregator.class);
this.targetType = TargetType.fromStream(in);
if (in.readBoolean()) {
this.classificationLabels = in.readStringList();
} else {
this.classificationLabels = null;
}
}
@Override
public List<String> getFeatureNames() {
return featureNames;
}
@Override
public double infer(Map<String, Object> fields) {
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
return infer(features);
}
@Override
public double infer(List<Double> fields) {
List<Double> processedInferences = inferAndProcess(fields);
return outputAggregator.aggregate(processedInferences);
}
@Override
public TargetType targetType() {
return targetType;
}
@Override
public List<Double> classificationProbability(Map<String, Object> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
return classificationProbability(features);
}
@Override
public List<Double> classificationProbability(List<Double> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
return inferAndProcess(fields);
}
@Override
public List<String> classificationLabels() {
return classificationLabels;
}
private List<Double> inferAndProcess(List<Double> fields) {
List<Double> modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList());
return outputAggregator.processValues(modelInferences);
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(featureNames);
out.writeNamedWriteableList(models);
out.writeNamedWriteable(outputAggregator);
targetType.writeTo(out);
out.writeBoolean(classificationLabels != null);
if (classificationLabels != null) {
out.writeStringCollection(classificationLabels);
}
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models);
NamedXContentObjectHelper.writeNamedObjects(builder,
params,
false,
AGGREGATE_OUTPUT.getPreferredName(),
Collections.singletonList(outputAggregator));
builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
if (classificationLabels != null) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Ensemble that = (Ensemble) o;
return Objects.equals(featureNames, that.featureNames)
&& Objects.equals(models, that.models)
&& Objects.equals(targetType, that.targetType)
&& Objects.equals(classificationLabels, that.classificationLabels)
&& Objects.equals(outputAggregator, that.outputAggregator);
}
@Override
public int hashCode() {
return Objects.hash(featureNames, models, outputAggregator, targetType, classificationLabels);
}
@Override
public void validate() {
if (this.featureNames != null) {
if (this.models.stream()
.anyMatch(trainedModel -> trainedModel.getFeatureNames().equals(this.featureNames) == false)) {
throw ExceptionsHelper.badRequestException(
"[{}] must be the same and in the same order for each of the {}",
FEATURE_NAMES.getPreferredName(),
TRAINED_MODELS.getPreferredName());
}
}
if (outputAggregator.expectedValueSize() != null &&
outputAggregator.expectedValueSize() != models.size()) {
throw ExceptionsHelper.badRequestException(
"[{}] expects value array of size [{}] but number of models is [{}]",
AGGREGATE_OUTPUT.getPreferredName(),
outputAggregator.expectedValueSize(),
models.size());
}
if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
throw ExceptionsHelper.badRequestException(
"[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
}
this.models.forEach(TrainedModel::validate);
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private List<String> featureNames;
private List<TrainedModel> trainedModels;
private OutputAggregator outputAggregator = new WeightedSum();
private TargetType targetType = TargetType.REGRESSION;
private List<String> classificationLabels;
private boolean modelsAreOrdered;
private Builder(boolean modelsAreOrdered) {
this.modelsAreOrdered = modelsAreOrdered;
}
private static Builder builderForParser() {
return new Builder(false);
}
public Builder() {
this(true);
}
public Builder setFeatureNames(List<String> featureNames) {
this.featureNames = featureNames;
return this;
}
public Builder setTrainedModels(List<TrainedModel> trainedModels) {
this.trainedModels = trainedModels;
return this;
}
public Builder setOutputAggregator(OutputAggregator outputAggregator) {
this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
return this;
}
public Builder setTargetType(TargetType targetType) {
this.targetType = targetType;
return this;
}
public Builder setClassificationLabels(List<String> classificationLabels) {
this.classificationLabels = classificationLabels;
return this;
}
private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
if (outputAggregators.size() != 1) {
throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.",
AGGREGATE_OUTPUT.getPreferredName());
}
this.setOutputAggregator(outputAggregators.get(0));
}
private void setTargetType(String targetType) {
this.targetType = TargetType.fromString(targetType);
}
private void setModelsAreOrdered(boolean value) {
this.modelsAreOrdered = value;
}
public Ensemble build() {
// This is essentially a serialization error but the underlying xcontent parsing does not allow us to inject this requirement
// So, we verify the models were parsed in an ordered fashion here instead.
if (modelsAreOrdered == false && trainedModels != null && trainedModels.size() > 1) {
throw ExceptionsHelper.badRequestException("[trained_models] needs to be an array of objects");
}
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
}
}
}
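To make the two-stage flow in inferAndProcess concrete, here is a worked sketch (invented numbers) using the WeightedSum aggregator defined later in this commit: each sub-model contributes one inference, processValues applies the weights, and aggregate reduces the vector to a single prediction.

List<Double> modelInferences = Arrays.asList(1.0, 2.0, 3.0); // hypothetical sub-model outputs
OutputAggregator aggregator = new WeightedSum(Arrays.asList(0.5, 0.25, 0.25));
List<Double> processed = aggregator.processValues(modelInferences); // [0.5, 0.5, 0.75]
double prediction = aggregator.aggregate(processed); // 0.5 + 0.5 + 0.75 = 1.75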


@@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
public interface LenientlyParsedOutputAggregator extends OutputAggregator {
}


@@ -0,0 +1,47 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import java.util.List;
public interface OutputAggregator extends NamedXContentObject, NamedWriteable {
/**
* @return The expected size of the values array when aggregating. `null` implies there is no expected size.
*/
Integer expectedValueSize();
/**
* This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(List)} method.
*
* Two major types of pre-processed values could be returned:
* - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(List)}
* - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(List)}
* @param values the values to process
* @return A new list containing the processed values or the same list if no processing is required
*/
List<Double> processValues(List<Double> values);
/**
* Function to aggregate the processed values into a single double
*
* This may be as simple as returning the index of the maximum value.
*
* Or as complex as a mathematical reduction of all the passed values (i.e. summation, average, etc.).
*
* @param processedValues The values to aggregate
* @return the aggregated value.
*/
double aggregate(List<Double> processedValues);
/**
* @return The name of the output aggregator
*/
String getName();
}


@@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
public interface StrictlyParsedOutputAggregator extends OutputAggregator {
}


@@ -0,0 +1,161 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
public static final ParseField NAME = new ParseField("weighted_mode");
public static final ParseField WEIGHTS = new ParseField("weights");
private static final ConstructingObjectParser<WeightedMode, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<WeightedMode, Void> STRICT_PARSER = createParser(false);
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<WeightedMode, Void> createParser(boolean lenient) {
ConstructingObjectParser<WeightedMode, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new WeightedMode((List<Double>)a[0]));
parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
return parser;
}
public static WeightedMode fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null);
}
public static WeightedMode fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null);
}
private final List<Double> weights;
WeightedMode() {
this.weights = null;
}
public WeightedMode(List<Double> weights) {
this.weights = weights == null ? null : Collections.unmodifiableList(weights);
}
public WeightedMode(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
} else {
this.weights = null;
}
}
@Override
public Integer expectedValueSize() {
return this.weights == null ? null : this.weights.size();
}
@Override
public List<Double> processValues(List<Double> values) {
Objects.requireNonNull(values, "values must not be null");
if (weights != null && values.size() != weights.size()) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
List<Integer> freqArray = new ArrayList<>();
Integer maxVal = 0;
for (Double value : values) {
if (value == null) {
throw new IllegalArgumentException("values must not contain null values");
}
if (Double.isNaN(value) || Double.isInfinite(value) || value < 0.0 || value != Math.rint(value)) {
throw new IllegalArgumentException("values must be whole, non-infinite, and positive");
}
Integer integerValue = value.intValue();
freqArray.add(integerValue);
if (integerValue > maxVal) {
maxVal = integerValue;
}
}
List<Double> frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY));
for (int i = 0; i < freqArray.size(); i++) {
Double weight = weights == null ? 1.0 : weights.get(i);
Integer value = freqArray.get(i);
Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
frequencies.set(value, frequency);
}
return softMax(frequencies);
}
@Override
public double aggregate(List<Double> values) {
Objects.requireNonNull(values, "values must not be null");
int bestValue = 0;
double bestFreq = Double.NEGATIVE_INFINITY;
for (int i = 0; i < values.size(); i++) {
if (values.get(i) == null) {
throw new IllegalArgumentException("values must not contain null values");
}
if (values.get(i) > bestFreq) {
bestFreq = values.get(i);
bestValue = i;
}
}
return bestValue;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(weights != null);
if (weights != null) {
out.writeCollection(weights, StreamOutput::writeDouble);
}
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (weights != null) {
builder.field(WEIGHTS.getPreferredName(), weights);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedMode that = (WeightedMode) o;
return Objects.equals(weights, that.weights);
}
@Override
public int hashCode() {
return Objects.hash(weights);
}
}
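A worked example (invented values): three sub-models vote for classes 1, 2, and 2 with weights 0.4, 0.3, and 0.3. processValues builds the weighted vote histogram [-Inf, 0.4, 0.6] and softMax turns it into probabilities of roughly [0.0, 0.45, 0.55]; aggregate then returns the arg-max index, class 2.

WeightedMode mode = new WeightedMode(Arrays.asList(0.4, 0.3, 0.3));
List<Double> probabilities = mode.processValues(Arrays.asList(1.0, 2.0, 2.0));
double predictedClass = mode.aggregate(probabilities); // 2.0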


@@ -0,0 +1,138 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
public static final ParseField NAME = new ParseField("weighted_sum");
public static final ParseField WEIGHTS = new ParseField("weights");
private static final ConstructingObjectParser<WeightedSum, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<WeightedSum, Void> STRICT_PARSER = createParser(false);
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<WeightedSum, Void> createParser(boolean lenient) {
ConstructingObjectParser<WeightedSum, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new WeightedSum((List<Double>)a[0]));
parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
return parser;
}
public static WeightedSum fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null);
}
public static WeightedSum fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null);
}
private final List<Double> weights;
WeightedSum() {
this.weights = null;
}
public WeightedSum(List<Double> weights) {
this.weights = weights == null ? null : Collections.unmodifiableList(weights);
}
public WeightedSum(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
} else {
this.weights = null;
}
}
@Override
public List<Double> processValues(List<Double> values) {
Objects.requireNonNull(values, "values must not be null");
if (weights == null) {
return values;
}
if (values.size() != weights.size()) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
return IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).boxed().collect(Collectors.toList());
}
@Override
public double aggregate(List<Double> values) {
Objects.requireNonNull(values, "values must not be null");
if (values.isEmpty()) {
throw new IllegalArgumentException("values must not be empty");
}
Optional<Double> summation = values.stream().reduce(Double::sum);
if (summation.isPresent()) {
return summation.get();
}
throw new IllegalArgumentException("values must not contain null values");
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(weights != null);
if (weights != null) {
out.writeCollection(weights, StreamOutput::writeDouble);
}
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (weights != null) {
builder.field(WEIGHTS.getPreferredName(), weights);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedSum that = (WeightedSum) o;
return Objects.equals(weights, that.weights);
}
@Override
public int hashCode() {
return Objects.hash(weights);
}
@Override
public Integer expectedValueSize() {
return weights == null ? null : this.weights.size();
}
}
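
A minimal usage sketch, not part of this commit (the wrapper class is hypothetical; the numbers match EnsembleTests.testRegressionInference below): processValues scales each model output by its weight, and aggregate sums the scaled values.

import java.util.Arrays;
import java.util.List;

public class WeightedSumExample {
    public static void main(String[] args) {
        WeightedSum sum = new WeightedSum(Arrays.asList(0.5, 0.5));
        List<Double> modelOutputs = Arrays.asList(0.3, 1.5);
        // 0.5 * 0.3 + 0.5 * 1.5 = 0.9
        System.out.println(sum.aggregate(sum.processValues(modelOutputs)));
    }
}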


@ -9,11 +9,13 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -31,10 +33,13 @@ import java.util.stream.Collectors;
public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
// TODO should we have regression/classification sub-classes that accept the builder?
public static final ParseField NAME = new ParseField("tree");
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
public static final ParseField TARGET_TYPE = new ParseField("target_type");
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
private static final ObjectParser<Tree.Builder, Void> LENIENT_PARSER = createParser(true);
private static final ObjectParser<Tree.Builder, Void> STRICT_PARSER = createParser(false);
@ -46,6 +51,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
Tree.Builder::new);
parser.declareStringArray(Tree.Builder::setFeatureNames, FEATURE_NAMES);
parser.declareObjectArray(Tree.Builder::setNodes, (p, c) -> TreeNode.fromXContent(p, lenient), TREE_STRUCTURE);
parser.declareString(Tree.Builder::setTargetType, TARGET_TYPE);
parser.declareStringArray(Tree.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
return parser;
}
@ -59,15 +66,28 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
private final List<String> featureNames;
private final List<TreeNode> nodes;
private final TargetType targetType;
private final List<String> classificationLabels;
private final CachedSupplier<Double> highestOrderCategory;
Tree(List<String> featureNames, List<TreeNode> nodes) {
Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE));
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
}
public Tree(StreamInput in) throws IOException {
this.featureNames = Collections.unmodifiableList(in.readStringList());
this.nodes = Collections.unmodifiableList(in.readList(TreeNode::new));
this.targetType = TargetType.fromStream(in);
if (in.readBoolean()) {
this.classificationLabels = Collections.unmodifiableList(in.readStringList());
} else {
this.classificationLabels = null;
}
this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
}
@Override
@ -90,7 +110,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
return infer(features);
}
private double infer(List<Double> features) {
@Override
public double infer(List<Double> features) {
TreeNode node = nodes.get(0);
while(node.isLeaf() == false) {
node = nodes.get(node.compare(features));
@ -115,13 +136,40 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
}
@Override
public boolean isClassification() {
return false;
public TargetType targetType() {
return targetType;
}
@Override
public List<Double> inferProbabilities(Map<String, Object> fields) {
throw new UnsupportedOperationException("Cannot infer probabilities against a regression model.");
public List<Double> classificationProbability(Map<String, Object> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
return classificationProbability(featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()));
}
@Override
public List<Double> classificationProbability(List<Double> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
double label = infer(fields);
// For classification, the inference return value is assumed to be whole.
assert label == Math.rint(label);
double maxCategory = this.highestOrderCategory.get();
// For classification, the largest leaf value is assumed to be whole.
assert maxCategory == Math.rint(maxCategory);
List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
// TODO, eventually have TreeNodes contain confidence levels
list.set(Double.valueOf(label).intValue(), 1.0);
return list;
}
@Override
public List<String> classificationLabels() {
return classificationLabels;
}
@Override
@ -133,6 +181,11 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(featureNames);
out.writeCollection(nodes);
targetType.writeTo(out);
out.writeBoolean(classificationLabels != null);
if (classificationLabels != null) {
out.writeStringCollection(classificationLabels);
}
}
@Override
@ -140,6 +193,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
builder.startObject();
builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
if(classificationLabels != null) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
}
builder.endObject();
return builder;
}
@ -155,22 +212,96 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
if (o == null || getClass() != o.getClass()) return false;
Tree that = (Tree) o;
return Objects.equals(featureNames, that.featureNames)
&& Objects.equals(nodes, that.nodes);
&& Objects.equals(nodes, that.nodes)
&& Objects.equals(targetType, that.targetType)
&& Objects.equals(classificationLabels, that.classificationLabels);
}
@Override
public int hashCode() {
return Objects.hash(featureNames, nodes);
return Objects.hash(featureNames, nodes, targetType, classificationLabels);
}
public static Builder builder() {
return new Builder();
}
@Override
public void validate() {
checkTargetType();
detectMissingNodes();
detectCycle();
}
private void checkTargetType() {
if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
throw ExceptionsHelper.badRequestException(
"[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
}
}
private void detectCycle() {
if (nodes.isEmpty()) {
return;
}
Set<Integer> visited = new HashSet<>(nodes.size());
Queue<Integer> toVisit = new ArrayDeque<>(nodes.size());
toVisit.add(0);
while(toVisit.isEmpty() == false) {
Integer nodeIdx = toVisit.remove();
if (visited.contains(nodeIdx)) {
throw ExceptionsHelper.badRequestException("[tree] contains cycle at node {}", nodeIdx);
}
visited.add(nodeIdx);
TreeNode treeNode = nodes.get(nodeIdx);
if (treeNode.getLeftChild() >= 0) {
toVisit.add(treeNode.getLeftChild());
}
if (treeNode.getRightChild() >= 0) {
toVisit.add(treeNode.getRightChild());
}
}
}
private void detectMissingNodes() {
if (nodes.isEmpty()) {
return;
}
List<Integer> missingNodes = new ArrayList<>();
for (int i = 0; i < nodes.size(); i++) {
TreeNode currentNode = nodes.get(i);
if (currentNode == null) {
continue;
}
if (nodeMissing(currentNode.getLeftChild(), nodes)) {
missingNodes.add(currentNode.getLeftChild());
}
if (nodeMissing(currentNode.getRightChild(), nodes)) {
missingNodes.add(currentNode.getRightChild());
}
}
if (missingNodes.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("[tree] contains missing nodes {}", missingNodes);
}
}
private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
return nodeIdx >= nodes.size();
}
private Double maxLeafValue() {
return targetType == TargetType.CLASSIFICATION ?
this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() :
null;
}
public static class Builder {
private List<String> featureNames;
private ArrayList<TreeNode.Builder> nodes;
private int numNodes;
private TargetType targetType = TargetType.REGRESSION;
private List<String> classificationLabels;
public Builder() {
nodes = new ArrayList<>();
@ -185,13 +316,18 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
return this;
}
public Builder setRoot(TreeNode.Builder root) {
nodes.set(0, root);
return this;
}
public Builder addNode(TreeNode.Builder node) {
nodes.add(node);
return this;
}
public Builder setNodes(List<TreeNode.Builder> nodes) {
this.nodes = new ArrayList<>(nodes);
this.nodes = new ArrayList<>(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE.getPreferredName()));
return this;
}
@ -199,6 +335,21 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
return setNodes(Arrays.asList(nodes));
}
public Builder setTargetType(TargetType targetType) {
this.targetType = targetType;
return this;
}
public Builder setClassificationLabels(List<String> classificationLabels) {
this.classificationLabels = classificationLabels;
return this;
}
private void setTargetType(String targetType) {
this.targetType = TargetType.fromString(targetType);
}
/**
* Add a decision node. Space for the child nodes is allocated
* @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index
@ -231,61 +382,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
return node;
}
void detectCycle(List<TreeNode.Builder> nodes) {
if (nodes.isEmpty()) {
return;
}
Set<Integer> visited = new HashSet<>();
Queue<Integer> toVisit = new ArrayDeque<>(nodes.size());
toVisit.add(0);
while(toVisit.isEmpty() == false) {
Integer nodeIdx = toVisit.remove();
if (visited.contains(nodeIdx)) {
throw new IllegalArgumentException("[tree] contains cycle at node " + nodeIdx);
}
visited.add(nodeIdx);
TreeNode.Builder treeNode = nodes.get(nodeIdx);
if (treeNode.getLeftChild() != null) {
toVisit.add(treeNode.getLeftChild());
}
if (treeNode.getRightChild() != null) {
toVisit.add(treeNode.getRightChild());
}
}
}
void detectNullOrMissingNode(List<TreeNode.Builder> nodes) {
if (nodes.isEmpty()) {
return;
}
if (nodes.get(0) == null) {
throw new IllegalArgumentException("[tree] must have non-null root node.");
}
List<Integer> nullOrMissingNodes = new ArrayList<>();
for (int i = 0; i < nodes.size(); i++) {
TreeNode.Builder currentNode = nodes.get(i);
if (currentNode == null) {
continue;
}
if (nodeNullOrMissing(currentNode.getLeftChild())) {
nullOrMissingNodes.add(currentNode.getLeftChild());
}
if (nodeNullOrMissing(currentNode.getRightChild())) {
nullOrMissingNodes.add(currentNode.getRightChild());
}
}
if (nullOrMissingNodes.isEmpty() == false) {
throw new IllegalArgumentException("[tree] contains null or missing nodes " + nullOrMissingNodes);
}
}
private boolean nodeNullOrMissing(Integer nodeIdx) {
if (nodeIdx == null) {
return false;
}
return nodeIdx >= nodes.size() || nodes.get(nodeIdx) == null;
}
/**
* Sets the node at {@code nodeIndex} to a leaf node.
* @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)}
@ -301,10 +397,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel
}
public Tree build() {
detectNullOrMissingNode(nodes);
detectCycle(nodes);
if (nodes.stream().anyMatch(Objects::isNull)) {
throw ExceptionsHelper.badRequestException("[tree] cannot contain null nodes");
}
return new Tree(featureNames,
nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()));
nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()),
targetType,
classificationLabels);
}
}
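
A minimal builder sketch, not part of this commit (the wrapper class is hypothetical; the tree shape and values come from TreeTests.testInfer below): addJunction allocates space for both children, addLeaf fills a child in, and validate() runs the cycle and missing-node checks added in this change.

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class TreeBuilderExample {
    public static void main(String[] args) {
        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
        // The root splits on feature 0 ("foo") at 0.5; its right child is a leaf.
        TreeNode.Builder root = builder.addJunction(0, 0, true, 0.5);
        builder.addLeaf(root.getRightChild(), 0.3);
        // The left child splits on feature 1 ("bar") at 0.8.
        TreeNode.Builder left = builder.addJunction(root.getLeftChild(), 1, true, 0.8);
        builder.addLeaf(left.getLeftChild(), 0.1);
        builder.addLeaf(left.getRightChild(), 0.2);
        Tree tree = builder.setFeatureNames(Arrays.asList("foo", "bar")).build();
        tree.validate(); // throws on cycles, missing nodes, or label/target-type mismatches
        Map<String, Object> features = new HashMap<>();
        features.put("foo", 0.3); // < 0.5, go left
        features.put("bar", 0.9); // >= 0.8, go right
        System.out.println(tree.infer(features)); // prints 0.2
    }
}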


@ -143,7 +143,7 @@ public class TreeNode implements ToXContentObject, Writeable {
}
public boolean isLeaf() {
return leftChild < 1;
return leftChild < 0;
}
public int compare(List<Double> features) {


@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.utils;
import java.util.List;
import java.util.stream.Collectors;
public final class Statistics {
private Statistics(){}
/**
* Calculates the softMax of the passed values.
*
* Any {@link Double#isInfinite()}, {@link Double#NaN}, or `null` values are ignored in the calculation and are returned
* as 0.0 in the resulting softMax.
* @param values Values on which to run SoftMax.
* @return A new list containing the softmax of the passed values
*/
public static List<Double> softMax(List<Double> values) {
Double expSum = 0.0;
Double max = values.stream().filter(v -> isInvalid(v) == false).max(Double::compareTo).orElse(null);
if (max == null) {
throw new IllegalArgumentException("no valid values present");
}
List<Double> exps = values.stream().map(v -> isInvalid(v) ? Double.NEGATIVE_INFINITY : v - max)
.collect(Collectors.toList());
for (int i = 0; i < exps.size(); i++) {
if (isInvalid(exps.get(i)) == false) {
Double exp = Math.exp(exps.get(i));
expSum += exp;
exps.set(i, exp);
}
}
for (int i = 0; i < exps.size(); i++) {
if (isInvalid(exps.get(i))) {
exps.set(i, 0.0);
} else {
exps.set(i, exps.get(i)/expSum);
}
}
return exps;
}
public static boolean isInvalid(Double v) {
return v == null || Double.isInfinite(v) || Double.isNaN(v);
}
}
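
A minimal usage sketch, not part of this commit (the wrapper class is hypothetical): invalid entries are excluded from the normalization and come back as 0.0.

import java.util.Arrays;
import java.util.List;

public class SoftMaxExample {
    public static void main(String[] args) {
        // Valid values are 1.0 and 5.0; the null entry is returned as 0.0.
        List<Double> softMax = Statistics.softMax(Arrays.asList(1.0, null, 5.0));
        // The max is 5.0, so the exponents are e^-4 and e^0: ~[0.017986, 0.0, 0.982014]
        System.out.println(softMax);
    }
}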


@ -17,6 +17,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
@ -157,7 +158,7 @@ public class NamedXContentObjectsTests extends AbstractXContentTestCase<NamedXCo
NamedObjectContainer container = new NamedObjectContainer();
container.setPreProcessors(preProcessors);
container.setUseExplicitPreprocessorOrder(true);
container.setModel(TreeTests.buildRandomTree(5, 4));
container.setModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom()));
return container;
}


@ -0,0 +1,402 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
private boolean lenient;
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
protected boolean supportsUnknownFields() {
return lenient;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> !field.isEmpty();
}
@Override
protected Ensemble doParseInstance(XContentParser parser) throws IOException {
return lenient ? Ensemble.fromXContentLenient(parser) : Ensemble.fromXContentStrict(parser);
}
public static Ensemble createRandom() {
int numberOfFeatures = randomIntBetween(1, 10);
List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList());
int numberOfModels = randomIntBetween(1, 10);
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
.limit(numberOfModels)
.collect(Collectors.toList());
List<Double> weights = randomBoolean() ?
null :
Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
List<String> categoryLabels = null;
if (randomBoolean()) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
return new Ensemble(featureNames,
models,
outputAggregator,
randomFrom(TargetType.values()),
categoryLabels);
}
@Override
protected Ensemble createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<Ensemble> instanceReader() {
return Ensemble::new;
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
}
public void testEnsembleWithModelsThatHaveDifferentFeatureNames() {
List<String> featureNames = Arrays.asList("foo", "bar", "baz", "farequote");
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder().setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("bar", "foo", "baz", "farequote"), 6)))
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));
ex = expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder().setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("completely_different"), 6)))
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));
}
public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() {
List<String> featureNames = Arrays.asList("foo", "bar");
int numberOfModels = 5;
List<Double> weights = new ArrayList<>(numberOfModels + 2);
for (int i = 0; i < numberOfModels + 2; i++) {
weights.add(randomDouble());
}
OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
List<TrainedModel> models = new ArrayList<>(numberOfModels);
for (int i = 0; i < numberOfModels; i++) {
models.add(TreeTests.buildRandomTree(featureNames, 6));
}
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder()
.setTrainedModels(models)
.setOutputAggregator(outputAggregator)
.setFeatureNames(featureNames)
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo("[aggregate_output] expects value array of size [7] but number of models is [5]"));
}
public void testEnsembleWithInvalidModel() {
List<String> featureNames = Arrays.asList("foo", "bar");
expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder()
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(
// Tree with loop
Tree.builder()
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()),
TreeNode.builder(0)
.setLeftChild(0)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setFeatureNames(featureNames)
.build()))
.build()
.validate();
});
}
public void testEnsembleWithTargetTypeAndLabelsMismatch() {
List<String> featureNames = Arrays.asList("foo", "bar");
String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder()
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(
Tree.builder()
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setFeatureNames(featureNames)
.build()))
.setClassificationLabels(Arrays.asList("label1", "label2"))
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
ex = expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder()
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(
Tree.builder()
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setFeatureNames(featureNames)
.build()))
.setTargetType(TargetType.CLASSIFICATION)
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
}
public void testClassificationProbability() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.build();
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
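// Derivation: for (0.4, 0.0) the trees vote [1.0, 0.0, 1.0]; with weights
// (0.7, 0.5, 1.0) the class frequencies are {0: 0.5, 1: 1.7}, and
// softMax([0.5, 1.7]) ~= [0.2315, 0.7685].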
List<Double> expected = Arrays.asList(0.231475216, 0.768524783);
double eps = 0.000001;
List<Double> probabilities = ensemble.classificationProbability(featureMap);
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
}
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
expected = Arrays.asList(0.3100255188, 0.689974481);
probabilities = ensemble.classificationProbability(featureMap);
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
}
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
expected = Arrays.asList(0.231475216, 0.768524783);
probabilities = ensemble.classificationProbability(featureMap);
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
}
}
public void testClassificationInference() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.build();
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
}
public void testRegressionInference() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.3))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.1))
.addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.5))
.addNode(TreeNode.builder(2).setLeafValue(0.9))
.build();
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2))
.setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5)))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
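// tree1 returns 0.3 and tree2 returns 1.5, so 0.5 * 0.3 + 0.5 * 1.5 = 0.9.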
assertEquals(0.9, ensemble.infer(featureMap), 0.00001);
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(0.5, ensemble.infer(featureMap), 0.00001);
// Test with NO aggregator supplied; verifies the default behavior of an unweighted sum
ensemble = Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2))
.build();
featureVector = Arrays.asList(0.4, 0.0);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.8, ensemble.infer(featureMap), 0.00001);
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
}
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}
}


@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.junit.Before;
import java.util.ArrayList;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
public abstract class WeightedAggregatorTests<T extends OutputAggregator> extends AbstractSerializingTestCase<T> {
protected boolean lenient;
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
protected boolean supportsUnknownFields() {
return lenient;
}
public void testWithNullValues() {
OutputAggregator outputAggregator = createTestInstance();
NullPointerException ex = expectThrows(NullPointerException.class, () -> outputAggregator.processValues(null));
assertThat(ex.getMessage(), equalTo("values must not be null"));
}
public void testWithValuesOfWrongLength() {
int numberOfValues = randomIntBetween(5, 10);
List<Double> values = new ArrayList<>(numberOfValues);
for (int i = 0; i < numberOfValues; i++) {
values.add(randomDouble());
}
OutputAggregator outputAggregatorWithTooFewWeights = createTestInstance(randomIntBetween(1, numberOfValues - 1));
expectThrows(IllegalArgumentException.class, () -> outputAggregatorWithTooFewWeights.processValues(values));
OutputAggregator outputAggregatorWithTooManyWeights = createTestInstance(randomIntBetween(numberOfValues + 1, numberOfValues + 10));
expectThrows(IllegalArgumentException.class, () -> outputAggregatorWithTooManyWeights.processValues(values));
}
abstract T createTestInstance(int numberOfWeights);
}

View File

@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.equalTo;
public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
@Override
WeightedMode createTestInstance(int numberOfWeights) {
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
return new WeightedMode(weights);
}
@Override
protected WeightedMode doParseInstance(XContentParser parser) throws IOException {
return lenient ? WeightedMode.fromXContentLenient(parser) : WeightedMode.fromXContentStrict(parser);
}
@Override
protected WeightedMode createTestInstance() {
return randomBoolean() ? new WeightedMode() : createTestInstance(randomIntBetween(1, 100));
}
@Override
protected Writeable.Reader<WeightedMode> instanceReader() {
return WeightedMode::new;
}
public void testAggregate() {
List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
WeightedMode weightedMode = new WeightedMode(ones);
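// Unit weights give frequencies {1: 1, 2: 2, 3: 1, 5: 1}, so the mode is 2.0;
// the varied weights below give value 5 a dominant frequency of 5.0.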
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
List<Double> variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0);
weightedMode = new WeightedMode(variedWeights);
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0));
weightedMode = new WeightedMode();
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
}
}


@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.equalTo;
public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {
@Override
WeightedSum createTestInstance(int numberOfWeights) {
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
return new WeightedSum(weights);
}
@Override
protected WeightedSum doParseInstance(XContentParser parser) throws IOException {
return lenient ? WeightedSum.fromXContentLenient(parser) : WeightedSum.fromXContentStrict(parser);
}
@Override
protected WeightedSum createTestInstance() {
return randomBoolean() ? new WeightedSum() : createTestInstance(randomIntBetween(1, 100));
}
@Override
protected Writeable.Reader<WeightedSum> instanceReader() {
return WeightedSum::new;
}
public void testAggregate() {
List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
WeightedSum weightedSum = new WeightedSum(ones);
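// Unit weights: 1 + 2 + 2 + 3 + 5 = 13.0; the varied weights below give
// 1 - 2 + 1 + 3 + 25 = 28.0.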
assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));
List<Double> variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0);
weightedSum = new WeightedSum(variedWeights);
assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(28.0));
weightedSum = new WeightedSum();
assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));
}
}


@ -5,9 +5,12 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.junit.Before;
import java.io.IOException;
@ -47,23 +50,23 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
return field -> field.startsWith("feature_names");
}
@Override
protected Tree createTestInstance() {
return createRandom();
}
public static Tree createRandom() {
return buildRandomTree(randomIntBetween(2, 15), 6);
}
public static Tree buildRandomTree(int numFeatures, int depth) {
Tree.Builder builder = Tree.builder();
List<String> featureNames = new ArrayList<>(numFeatures);
for(int i = 0; i < numFeatures; i++) {
int numberOfFeatures = randomIntBetween(1, 10);
List<String> featureNames = new ArrayList<>();
for (int i = 0; i < numberOfFeatures; i++) {
featureNames.add(randomAlphaOfLength(10));
}
return buildRandomTree(featureNames, 6);
}
public static Tree buildRandomTree(List<String> featureNames, int depth) {
Tree.Builder builder = Tree.builder();
int numFeatures = featureNames.size() - 1;
builder.setFeatureNames(featureNames);
TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
@ -84,8 +87,14 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}
childNodes = nextNodes;
}
List<String> categoryLabels = null;
if (randomBoolean()) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
return builder.build();
return builder.setTargetType(randomFrom(TargetType.values()))
.setClassificationLabels(categoryLabels)
.build();
}
@Override
@ -96,7 +105,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
public void testInfer() {
// Build a tree with 2 nodes and 3 leaves using 2 features
// The leaves have unique values 0.1, 0.2, 0.3
Tree.Builder builder = Tree.builder();
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), 0.3);
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
@ -124,37 +133,76 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
assertEquals(0.2, tree.infer(featureMap), 0.00001);
}
public void testTreeClassificationProbability() {
// Build a tree with 2 nodes and 3 leaves using 2 features
// The leaves have class values 1.0, 1.0, and 0.0
Tree.Builder builder = Tree.builder().setTargetType(TargetType.CLASSIFICATION);
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), 1.0);
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
builder.addLeaf(leftChildNode.getLeftChild(), 1.0);
builder.addLeaf(leftChildNode.getRightChild(), 0.0);
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree = builder.setFeatureNames(featureNames).build();
// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));
// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
featureVector = Arrays.asList(0.3, 0.9);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap));
}
public void testTreeWithNullRoot() {
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
() -> Tree.builder().setNodes(Collections.singletonList(null))
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
() -> Tree.builder()
.setNodes(Collections.singletonList(null))
.setFeatureNames(Arrays.asList("foo", "bar"))
.build());
assertThat(ex.getMessage(), equalTo("[tree] must have non-null root node."));
assertThat(ex.getMessage(), equalTo("[tree] cannot contain null nodes"));
}
public void testTreeWithInvalidNode() {
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
() -> Tree.builder().setNodes(TreeNode.builder(0)
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
() -> Tree.builder()
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.build());
assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]"));
.setFeatureNames(Arrays.asList("foo", "bar"))
.build().validate());
assertThat(ex.getMessage(), equalTo("[tree] contains missing nodes [1]"));
}
public void testTreeWithNullNode() {
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
() -> Tree.builder().setNodes(TreeNode.builder(0)
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
() -> Tree.builder()
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()),
null)
.build());
assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]"));
.setFeatureNames(Arrays.asList("foo", "bar"))
.build()
.validate());
assertThat(ex.getMessage(), equalTo("[tree] cannot contain null nodes"));
}
public void testTreeWithCycle() {
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
() -> Tree.builder().setNodes(TreeNode.builder(0)
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
() -> Tree.builder()
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()),
@ -162,10 +210,41 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
.setLeftChild(0)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.build());
.setFeatureNames(Arrays.asList("foo", "bar"))
.build()
.validate());
assertThat(ex.getMessage(), equalTo("[tree] contains cycle at node 0"));
}
public void testTreeWithTargetTypeAndLabelsMismatch() {
List<String> featureNames = Arrays.asList("foo", "bar");
String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Tree.builder()
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setFeatureNames(featureNames)
.setClassificationLabels(Arrays.asList("label1", "label2"))
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
ex = expectThrows(ElasticsearchException.class, () -> {
Tree.builder()
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setFeatureNames(featureNames)
.setTargetType(TargetType.CLASSIFICATION)
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
}
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}


@ -0,0 +1,33 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.utils;
import org.elasticsearch.test.ESTestCase;
import java.util.Arrays;
import java.util.List;
import static org.hamcrest.Matchers.closeTo;
public class StatisticsTests extends ESTestCase {
public void testSoftMax() {
List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, 1.0, -0.5, null, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0);
List<Double> softMax = Statistics.softMax(values);
List<Double> expected = Arrays.asList(0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042);
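// The valid entries (1.0, -0.5, 1.0, 5.0) shift by the max 5.0 to exponents
// e^-4, e^-5.5, e^-4, e^0, which normalize to the expected values above.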
for(int i = 0; i < expected.size(); i++) {
assertThat(softMax.get(i), closeTo(expected.get(i), 0.000001));
}
}
public void testSoftMaxWithNoValidValues() {
List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, null, Double.NaN, Double.POSITIVE_INFINITY);
expectThrows(IllegalArgumentException.class, () -> Statistics.softMax(values));
}
}