From b9d9964d1026e104ea982bd7a1a91831e44c00d5 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 3 Jul 2020 14:51:00 -0400 Subject: [PATCH] [ML] add exponent output aggregator to inference (#58933) (#59016) * [ML] add exponent output aggregator to inference * fixing docs Co-authored-by: Elastic Machine --- .../MlInferenceNamedXContentProvider.java | 4 + .../trainedmodel/ensemble/Exponent.java | 83 ++++++++ .../client/RestHighLevelClientTests.java | 7 +- .../trainedmodel/ensemble/EnsembleTests.java | 3 +- .../trainedmodel/ensemble/ExponentTests.java | 51 +++++ .../df-analytics/apis/put-inference.asciidoc | 180 ++++++++++-------- .../MlInferenceNamedXContentProvider.java | 10 + .../trainedmodel/ensemble/Exponent.java | 157 +++++++++++++++ .../trainedmodel/ensemble/EnsembleTests.java | 3 +- .../trainedmodel/ensemble/ExponentTests.java | 69 +++++++ 10 files changed, 487 insertions(+), 80 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Exponent.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/ExponentTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Exponent.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/ExponentTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java index d8a00f1d8f6..30c7598d72f 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java @@ -24,6 +24,7 @@ import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode; @@ -82,6 +83,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider { namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, new ParseField(LogisticRegression.NAME), LogisticRegression::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, + new ParseField(Exponent.NAME), + Exponent::fromXContent)); return namedXContent; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Exponent.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Exponent.java new file mode 100644 index 00000000000..2247006ed0a --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Exponent.java @@ -0,0 +1,83 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + + +public class Exponent implements OutputAggregator { + + public static final String NAME = "exponent"; + public static final ParseField WEIGHTS = new ParseField("weights"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + a -> new Exponent((List)a[0])); + static { + PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); + } + + public static Exponent fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final List weights; + + public Exponent(List weights) { + this.weights = weights; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), weights); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Exponent that = (Exponent) o; + return Objects.equals(weights, that.weights); + } + + @Override + public int hashCode() { + return Objects.hash(weights); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 835012eabb4..6ee55810a57 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -80,6 +80,7 @@ import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding; import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum; @@ -703,7 +704,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(68, namedXContents.size()); + assertEquals(69, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -788,9 +789,9 @@ public class RestHighLevelClientTests extends ESTestCase { assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class)); assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME)); - assertEquals(Integer.valueOf(3), + assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class)); - assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME)); + assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME, Exponent.NAME)); assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig.class)); assertThat(names, hasItems(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName())); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 318d390de7c..507bd770609 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -73,7 +73,8 @@ public class EnsembleTests extends AbstractXContentTestCase { categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10)); } List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); - OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) : + OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? + randomFrom(new WeightedSum(weights), new Exponent(weights)) : randomFrom( new WeightedMode( categoryLabels != null ? categoryLabels.size() : randomIntBetween(2, 10), diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/ExponentTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/ExponentTests.java new file mode 100644 index 00000000000..4036637a4ae --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/ExponentTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class ExponentTests extends AbstractXContentTestCase { + + Exponent createTestInstance(int numberOfWeights) { + return new Exponent(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList())); + } + + @Override + protected Exponent doParseInstance(XContentParser parser) throws IOException { + return Exponent.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Exponent createTestInstance() { + return randomBoolean() ? new Exponent(null) : createTestInstance(randomIntBetween(1, 100)); + } + +} diff --git a/docs/reference/ml/df-analytics/apis/put-inference.asciidoc b/docs/reference/ml/df-analytics/apis/put-inference.asciidoc index 5ca72acc1aa..176ab7e662a 100644 --- a/docs/reference/ml/df-analytics/apis/put-inference.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-inference.asciidoc @@ -27,7 +27,7 @@ experimental[] [[ml-put-inference-prereq]] ==== {api-prereq-title} -If the {es} {security-features} are enabled, you must have the following +If the {es} {security-features} are enabled, you must have the following built-in roles and privileges: * `machine_learning_admin` @@ -38,15 +38,15 @@ For more information, see <> and <>. [[ml-put-inference-desc]] ==== {api-description-title} -The create {infer} trained model API enables you to supply a trained model that -is not created by {dfanalytics}. +The create {infer} trained model API enables you to supply a trained model that +is not created by {dfanalytics}. [[ml-put-inference-path-params]] ==== {api-path-parms-title} ``:: -(Required, string) +(Required, string) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] [role="child_attributes"] @@ -54,14 +54,14 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] ==== {api-request-body-title} `compressed_definition`:: -(Required, string) -The compressed (GZipped and Base64 encoded) {infer} definition of the model. +(Required, string) +The compressed (GZipped and Base64 encoded) {infer} definition of the model. If `compressed_definition` is specified, then `definition` cannot be specified. //Begin definition `definition`:: -(Required, object) -The {infer} definition for the model. If `definition` is specified, then +(Required, object) +The {infer} definition for the model. If `definition` is specified, then `compressed_definition` cannot be specified. + .Properties of `definition` @@ -77,58 +77,58 @@ Collection of preprocessors. See <>. ===== //Begin frequency encoding `frequency_encoding`:: -(Required, object) +(Required, object) Defines a frequency encoding for a field. + .Properties of `frequency_encoding` [%collapsible%open] ====== `feature_name`:: -(Required, string) +(Required, string) The name of the resulting feature. `field`:: -(Required, string) +(Required, string) The field name to encode. `frequency_map`:: -(Required, object map of string:double) +(Required, object map of string:double) Object that maps the field value to the frequency encoded value. ====== //End frequency encoding //Begin one hot encoding `one_hot_encoding`:: -(Required, object) +(Required, object) Defines a one hot encoding map for a field. + .Properties of `one_hot_encoding` [%collapsible%open] ====== `field`:: -(Required, string) +(Required, string) The field name to encode. `hot_map`:: -(Required, object map of strings) +(Required, object map of strings) String map of "field_value: one_hot_column_name". ====== //End one hot encoding //Begin target mean encoding `target_mean_encoding`:: -(Required, object) +(Required, object) Defines a target mean encoding for a field. + .Properties of `target_mean_encoding` [%collapsible%open] ====== `default_value`::: -(Required, double) +(Required, double) The feature value if the field value is not in the `target_map`. `feature_name`::: -(Required, string) +(Required, string) The name of the resulting feature. `field`::: @@ -136,7 +136,7 @@ The name of the resulting feature. The field name to encode. `target_map`::: -(Required, object map of string:double) +(Required, object map of string:double) Object that maps the field value to the target mean value. ====== //End target mean encoding @@ -145,7 +145,7 @@ Object that maps the field value to the target mean value. //Begin trained model `trained_model`:: -(Required, object) +(Required, object) The definition of the trained model. + .Properties of `trained_model` @@ -153,14 +153,14 @@ The definition of the trained model. ===== //Begin tree `tree`:: -(Required, object) +(Required, object) The definition for a binary decision tree. + .Properties of `tree` [%collapsible%open] ====== `classification_labels`::: -(Optional, string) An array of classification labels (used for +(Optional, string) An array of classification labels (used for `classification`). `feature_names`::: @@ -168,26 +168,26 @@ The definition for a binary decision tree. Features expected by the tree, in their expected order. `target_type`::: -(Required, string) +(Required, string) String indicating the model target type; `regression` or `classification`. `tree_structure`::: -(Required, object) -An array of `tree_node` objects. The nodes must be in ordinal order by their +(Required, object) +An array of `tree_node` objects. The nodes must be in ordinal order by their `tree_node.node_index` value. ====== //End tree //Begin tree node `tree_node`:: -(Required, object) +(Required, object) The definition of a node in a tree. + -- There are two major types of nodes: leaf nodes and not-leaf nodes. * Leaf nodes only need `node_index` and `leaf_value` defined. -* All other nodes need `split_feature`, `left_child`, `right_child`, +* All other nodes need `split_feature`, `left_child`, `right_child`, `threshold`, `decision_type`, and `default_left` defined. -- + @@ -195,41 +195,41 @@ There are two major types of nodes: leaf nodes and not-leaf nodes. [%collapsible%open] ====== `decision_type`:: -(Optional, string) -Indicates the positive value (in other words, when to choose the left node) +(Optional, string) +Indicates the positive value (in other words, when to choose the left node) decision type. Supported `lt`, `lte`, `gt`, `gte`. Defaults to `lte`. `default_left`:: -(Optional, boolean) -Indicates whether to default to the left when the feature is missing. Defaults +(Optional, boolean) +Indicates whether to default to the left when the feature is missing. Defaults to `true`. `leaf_value`:: -(Optional, double) -The leaf value of the of the node, if the value is a leaf (in other words, no +(Optional, double) +The leaf value of the of the node, if the value is a leaf (in other words, no children). `left_child`:: -(Optional, integer) +(Optional, integer) The index of the left child. `node_index`:: -(Integer) +(Integer) The index of the current node. `right_child`:: -(Optional, integer) +(Optional, integer) The index of the right child. `split_feature`:: -(Optional, integer) +(Optional, integer) The index of the feature value in the feature array. `split_gain`:: (Optional, double) The information gain from the split. `threshold`:: -(Optional, double) +(Optional, double) The decision threshold with which to compare the feature value. ====== //End tree node @@ -244,9 +244,9 @@ The definition for an ensemble model. See <>. ====== //Begin aggregate output `aggregate_output`:: -(Required, object) -An aggregated output object that defines how to aggregate the outputs of the -`trained_models`. Supported objects are `weighted_mode`, `weighted_sum`, and +(Required, object) +An aggregated output object that defines how to aggregate the outputs of the +`trained_models`. Supported objects are `weighted_mode`, `weighted_sum`, and `logistic_regression`. See <>. + .Properties of `aggregate_output` @@ -254,65 +254,82 @@ An aggregated output object that defines how to aggregate the outputs of the ======= //Begin logistic regression `logistic_regression`:: -(Optional, object) -This `aggregated_output` type works with binary classification (classification -for values [0, 1]). It multiplies the outputs (in the case of the `ensemble` -model, the inference model values) by the supplied `weights`. The resulting -vector is summed and passed to a -https://en.wikipedia.org/wiki/Sigmoid_function[`sigmoid` function]. The result -of the `sigmoid` function is considered the probability of class 1 (`P_1`), -consequently, the probability of class 0 is `1 - P_1`. The class with the -highest probability (either 0 or 1) is then returned. For more information about -logistic regression, see +(Optional, object) +This `aggregated_output` type works with binary classification (classification +for values [0, 1]). It multiplies the outputs (in the case of the `ensemble` +model, the inference model values) by the supplied `weights`. The resulting +vector is summed and passed to a +https://en.wikipedia.org/wiki/Sigmoid_function[`sigmoid` function]. The result +of the `sigmoid` function is considered the probability of class 1 (`P_1`), +consequently, the probability of class 0 is `1 - P_1`. The class with the +highest probability (either 0 or 1) is then returned. For more information about +logistic regression, see https://en.wikipedia.org/wiki/Logistic_regression[this wiki article]. + .Properties of `logistic_regression` [%collapsible%open] ======== `weights`::: -(Required, double) -The weights to multiply by the input values (the inference values of the trained +(Required, double) +The weights to multiply by the input values (the inference values of the trained models). ======== //End logistic regression //Begin weighted sum `weighted_sum`:: -(Optional, object) -This `aggregated_output` type works with regression. The weighted sum of the +(Optional, object) +This `aggregated_output` type works with regression. The weighted sum of the input values. + .Properties of `weighted_sum` [%collapsible%open] ======== `weights`::: -(Required, double) -The weights to multiply by the input values (the inference values of the trained +(Required, double) +The weights to multiply by the input values (the inference values of the trained models). ======== //End weighted sum //Begin weighted mode `weighted_mode`:: -(Optional, object) -This `aggregated_output` type works with regression or classification. It takes -a weighted vote of the input values. The most common input value (taking the +(Optional, object) +This `aggregated_output` type works with regression or classification. It takes +a weighted vote of the input values. The most common input value (taking the weights into account) is returned. + .Properties of `weighted_mode` [%collapsible%open] ======== `weights`::: -(Required, double) -The weights to multiply by the input values (the inference values of the trained +(Required, double) +The weights to multiply by the input values (the inference values of the trained models). ======== //End weighted mode + +//Begin exponent +`exponent`:: +(Optional, object) +This `aggregated_output` type works with regression. It takes a weighted sum of +the input values and passes the result to an exponent function +(`e^x` where `x` is the sum of the weighted values). ++ +.Properties of `exponent` +[%collapsible%open] +======== +`weights`::: +(Required, double) +The weights to multiply by the input values (the inference values of the trained +models). +======== +//End exponent ======= //End aggregate output `classification_labels`:: -(Optional, string) +(Optional, string) An array of classification labels. `feature_names`:: @@ -320,12 +337,12 @@ An array of classification labels. Features expected by the ensemble, in their expected order. `target_type`:: -(Required, string) +(Required, string) String indicating the model target type; `regression` or `classification.` `trained_models`:: (Required, object) -An array of `trained_model` objects. Supported trained models are `tree` and +An array of `trained_model` objects. Supported trained models are `tree` and `ensemble`. ====== //End ensemble @@ -337,7 +354,7 @@ An array of `trained_model` objects. Supported trained models are `tree` and //End definition `description`:: -(Optional, string) +(Optional, string) A human-readable description of the {infer} trained model. //Begin inference_config @@ -398,24 +415,24 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification //Begin input `input`:: -(Required, object) +(Required, object) The input field names for the model definition. + .Properties of `input` [%collapsible%open] ==== `field_names`::: -(Required, string) +(Required, string) An array of input field names for the model. ==== //End input `metadata`:: -(Optional, object) +(Optional, object) An object map that contains metadata about the model. `tags`:: -(Optional, string) +(Optional, string) An array of tags to organize the model. @@ -451,10 +468,10 @@ The next example shows a `one_hot_encoding` preprocessor object: [source,js] ---------------------------------- -{ - "one_hot_encoding":{ +{ + "one_hot_encoding":{ "field":"FlightDelayType", - "hot_map":{ + "hot_map":{ "Carrier Delay":"FlightDelayType_Carrier Delay", "NAS Delay":"FlightDelayType_NAS Delay", "No Delay":"FlightDelayType_No Delay", @@ -521,7 +538,7 @@ The first example shows a `trained_model` object: "left_child":1, "right_child":2 }, - ... + ... { "node_index":9, "leaf_value":-27.68987349695448 @@ -615,8 +632,21 @@ Example of a `weighted_mode` object: //NOTCONSOLE +Example of an `exponent` object: + +[source,js] +---------------------------------- +"aggregate_output" : { + "exponent" : { + "weights" : [1.0, 1.0, 1.0, 1.0, 1.0] + } +} +---------------------------------- +//NOTCONSOLE + + [[ml-put-inference-json-schema]] ===== {infer-cap} JSON schema -For the full JSON schema of model {infer}, +For the full JSON schema of model {infer}, https://github.com/elastic/ml-json-schemas[click here]. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 565e2c67ff4..203329da131 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInfe import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Exponent; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; @@ -90,6 +91,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider { namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class, LogisticRegression.NAME, LogisticRegression::fromXContentLenient)); + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class, + Exponent.NAME, + Exponent::fromXContentLenient)); // Model Strict namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict)); @@ -108,6 +112,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider { namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class, LogisticRegression.NAME, LogisticRegression::fromXContentStrict)); + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class, + Exponent.NAME, + Exponent::fromXContentStrict)); // Inference Configs namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class, ClassificationConfig.NAME, @@ -166,6 +173,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider { namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class, LogisticRegression.NAME.getPreferredName(), LogisticRegression::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class, + Exponent.NAME.getPreferredName(), + Exponent::new)); // Inference Results namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Exponent.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Exponent.java new file mode 100644 index 00000000000..ae7098c6226 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Exponent.java @@ -0,0 +1,157 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +public class Exponent implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { + + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Exponent.class); + public static final ParseField NAME = new ParseField("exponent"); + public static final ParseField WEIGHTS = new ParseField("weights"); + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( + NAME.getPreferredName(), + lenient, + a -> new Exponent((List)a[0])); + parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); + return parser; + } + + public static Exponent fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null); + } + + public static Exponent fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null); + } + + private final double[] weights; + + Exponent() { + this((List) null); + } + + private Exponent(List weights) { + this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray()); + } + + public Exponent(double[] weights) { + this.weights = weights; + } + + public Exponent(StreamInput in) throws IOException { + if (in.readBoolean()) { + this.weights = in.readDoubleArray(); + } else { + this.weights = null; + } + } + + @Override + public Integer expectedValueSize() { + return this.weights == null ? null : this.weights.length; + } + + @Override + public double[] processValues(double[][] values) { + Objects.requireNonNull(values, "values must not be null"); + if (weights != null && values.length != weights.length) { + throw new IllegalArgumentException("values must be the same length as weights."); + } + assert values[0].length == 1; + double[] processed = new double[values.length]; + for (int i = 0; i < values.length; ++i) { + if (weights != null) { + processed[i] = weights[i] * values[i][0]; + } else { + processed[i] = values[i][0]; + } + } + return processed; + } + + @Override + public double aggregate(double[] values) { + Objects.requireNonNull(values, "values must not be null"); + double sum = 0.0; + for (double val : values) { + if (Double.isFinite(val)) { + sum += val; + } + } + return Math.exp(sum); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public boolean compatibleWith(TargetType targetType) { + return TargetType.REGRESSION.equals(targetType); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(weights != null); + if (weights != null) { + out.writeDoubleArray(weights); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), weights); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Exponent that = (Exponent) o; + return Arrays.equals(weights, that.weights); + } + + @Override + public int hashCode() { + return Arrays.hashCode(weights); + } + + @Override + public long ramBytesUsed() { + long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights); + return SHALLOW_SIZE + weightSize; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 6717ef32c20..51179db17f4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -80,7 +80,8 @@ public class EnsembleTests extends AbstractSerializingTestCase { categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10)); } - OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) : + OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? + randomFrom(new WeightedSum(weights), new Exponent(weights)) : randomFrom( new WeightedMode( weights, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/ExponentTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/ExponentTests.java new file mode 100644 index 00000000000..d4e196827d6 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/ExponentTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; + +import java.io.IOException; +import java.util.stream.Stream; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.closeTo; + +public class ExponentTests extends WeightedAggregatorTests { + + @Override + Exponent createTestInstance(int numberOfWeights) { + double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray(); + return new Exponent(weights); + } + + @Override + protected Exponent doParseInstance(XContentParser parser) throws IOException { + return lenient ? Exponent.fromXContentLenient(parser) : Exponent.fromXContentStrict(parser); + } + + @Override + protected Exponent createTestInstance() { + return randomBoolean() ? new Exponent() : createTestInstance(randomIntBetween(1, 100)); + } + + @Override + protected Writeable.Reader instanceReader() { + return Exponent::new; + } + + public void testAggregate() { + double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0}; + double[][] values = new double[][]{ + new double[] {.01}, + new double[] {.2}, + new double[] {.002}, + new double[] {-.01}, + new double[] {.1} + }; + + Exponent exponent = new Exponent(ones); + assertThat(exponent.aggregate(exponent.processValues(values)), closeTo(1.35256, 0.00001)); + + double[] variedWeights = new double[]{.01, -1.0, .1, 0.0, 0.0}; + + exponent = new Exponent(variedWeights); + assertThat(exponent.aggregate(exponent.processValues(values)), closeTo(0.81897, 0.00001)); + + exponent = new Exponent(); + assertThat(exponent.aggregate(exponent.processValues(values)), closeTo(1.35256, 0.00001)); + } + + public void testCompatibleWith() { + Exponent exponent = createTestInstance(); + assertThat(exponent.compatibleWith(TargetType.CLASSIFICATION), is(false)); + assertThat(exponent.compatibleWith(TargetType.REGRESSION), is(true)); + } +}