[ML] add exponent output aggregator to inference (#58933) (#59016)

* [ML] add exponent output aggregator to inference

* fixing docs

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Benjamin Trent 2020-07-03 14:51:00 -04:00 committed by GitHub
parent 935a49a8d6
commit b9d9964d10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 487 additions and 80 deletions

View File

@ -24,6 +24,7 @@ import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
@ -82,6 +83,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
new ParseField(LogisticRegression.NAME),
LogisticRegression::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
new ParseField(Exponent.NAME),
Exponent::fromXContent));
return namedXContent;
}

View File

@ -0,0 +1,83 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
/**
 * Client-side representation of the {@code exponent} ensemble output aggregator.
 * Carries an optional list of per-model weights and round-trips them via XContent.
 */
public class Exponent implements OutputAggregator {

    public static final String NAME = "exponent";
    public static final ParseField WEIGHTS = new ParseField("weights");

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<Exponent, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
        args -> new Exponent((List<Double>) args[0]));

    static {
        // weights is optional; absent means an unweighted aggregation
        PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
    }

    public static Exponent fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /** Per-model weights, or {@code null} when no weights were supplied. */
    private final List<Double> weights;

    public Exponent(List<Double> weights) {
        this.weights = weights;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // Only emit the field when weights were explicitly provided
        if (weights != null) {
            builder.field(WEIGHTS.getPreferredName(), weights);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object other) {
        if (other == this) {
            return true;
        }
        if (other == null || other.getClass() != getClass()) {
            return false;
        }
        Exponent rhs = (Exponent) other;
        return Objects.equals(weights, rhs.weights);
    }

    @Override
    public int hashCode() {
        return Objects.hash(weights);
    }
}

View File

@ -80,6 +80,7 @@ import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
@ -703,7 +704,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(68, namedXContents.size());
assertEquals(69, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -788,9 +789,9 @@ public class RestHighLevelClientTests extends ESTestCase {
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
assertEquals(Integer.valueOf(3),
assertEquals(Integer.valueOf(4),
categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME));
assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME, Exponent.NAME));
assertEquals(Integer.valueOf(2),
categories.get(org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig.class));
assertThat(names, hasItems(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName()));

View File

@ -73,7 +73,8 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));
}
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) :
OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ?
randomFrom(new WeightedSum(weights), new Exponent(weights)) :
randomFrom(
new WeightedMode(
categoryLabels != null ? categoryLabels.size() : randomIntBetween(2, 10),

View File

@ -0,0 +1,51 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * XContent round-trip tests for the client-side {@link Exponent} aggregator.
 */
public class ExponentTests extends AbstractXContentTestCase<Exponent> {

    // Builds an instance with the requested number of random weights.
    Exponent createTestInstance(int numberOfWeights) {
        return new Exponent(
            Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
    }

    @Override
    protected Exponent doParseInstance(XContentParser parser) throws IOException {
        return Exponent.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // The parser is lenient, so unknown fields must be tolerated.
        return true;
    }

    @Override
    protected Exponent createTestInstance() {
        // Exercise both the "no weights" and the weighted form.
        if (randomBoolean()) {
            return new Exponent(null);
        }
        return createTestInstance(randomIntBetween(1, 100));
    }
}

View File

@ -27,7 +27,7 @@ experimental[]
[[ml-put-inference-prereq]]
==== {api-prereq-title}
If the {es} {security-features} are enabled, you must have the following
If the {es} {security-features} are enabled, you must have the following
built-in roles and privileges:
* `machine_learning_admin`
@ -38,15 +38,15 @@ For more information, see <<security-privileges>> and <<built-in-roles>>.
[[ml-put-inference-desc]]
==== {api-description-title}
The create {infer} trained model API enables you to supply a trained model that
is not created by {dfanalytics}.
The create {infer} trained model API enables you to supply a trained model that
is not created by {dfanalytics}.
[[ml-put-inference-path-params]]
==== {api-path-parms-title}
`<model_id>`::
(Required, string)
(Required, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
[role="child_attributes"]
@ -54,14 +54,14 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
==== {api-request-body-title}
`compressed_definition`::
(Required, string)
The compressed (GZipped and Base64 encoded) {infer} definition of the model.
(Required, string)
The compressed (GZipped and Base64 encoded) {infer} definition of the model.
If `compressed_definition` is specified, then `definition` cannot be specified.
//Begin definition
`definition`::
(Required, object)
The {infer} definition for the model. If `definition` is specified, then
(Required, object)
The {infer} definition for the model. If `definition` is specified, then
`compressed_definition` cannot be specified.
+
.Properties of `definition`
@ -77,58 +77,58 @@ Collection of preprocessors. See <<ml-put-inference-preprocessor-example>>.
=====
//Begin frequency encoding
`frequency_encoding`::
(Required, object)
(Required, object)
Defines a frequency encoding for a field.
+
.Properties of `frequency_encoding`
[%collapsible%open]
======
`feature_name`::
(Required, string)
(Required, string)
The name of the resulting feature.
`field`::
(Required, string)
(Required, string)
The field name to encode.
`frequency_map`::
(Required, object map of string:double)
(Required, object map of string:double)
Object that maps the field value to the frequency encoded value.
======
//End frequency encoding
//Begin one hot encoding
`one_hot_encoding`::
(Required, object)
(Required, object)
Defines a one hot encoding map for a field.
+
.Properties of `one_hot_encoding`
[%collapsible%open]
======
`field`::
(Required, string)
(Required, string)
The field name to encode.
`hot_map`::
(Required, object map of strings)
(Required, object map of strings)
String map of "field_value: one_hot_column_name".
======
//End one hot encoding
//Begin target mean encoding
`target_mean_encoding`::
(Required, object)
(Required, object)
Defines a target mean encoding for a field.
+
.Properties of `target_mean_encoding`
[%collapsible%open]
======
`default_value`:::
(Required, double)
(Required, double)
The feature value if the field value is not in the `target_map`.
`feature_name`:::
(Required, string)
(Required, string)
The name of the resulting feature.
`field`:::
@ -136,7 +136,7 @@ The name of the resulting feature.
The field name to encode.
`target_map`:::
(Required, object map of string:double)
(Required, object map of string:double)
Object that maps the field value to the target mean value.
======
//End target mean encoding
@ -145,7 +145,7 @@ Object that maps the field value to the target mean value.
//Begin trained model
`trained_model`::
(Required, object)
(Required, object)
The definition of the trained model.
+
.Properties of `trained_model`
@ -153,14 +153,14 @@ The definition of the trained model.
=====
//Begin tree
`tree`::
(Required, object)
(Required, object)
The definition for a binary decision tree.
+
.Properties of `tree`
[%collapsible%open]
======
`classification_labels`:::
(Optional, string) An array of classification labels (used for
(Optional, string) An array of classification labels (used for
`classification`).
`feature_names`:::
@ -168,26 +168,26 @@ The definition for a binary decision tree.
Features expected by the tree, in their expected order.
`target_type`:::
(Required, string)
(Required, string)
String indicating the model target type; `regression` or `classification`.
`tree_structure`:::
(Required, object)
An array of `tree_node` objects. The nodes must be in ordinal order by their
(Required, object)
An array of `tree_node` objects. The nodes must be in ordinal order by their
`tree_node.node_index` value.
======
//End tree
//Begin tree node
`tree_node`::
(Required, object)
(Required, object)
The definition of a node in a tree.
+
--
There are two major types of nodes: leaf nodes and not-leaf nodes.
* Leaf nodes only need `node_index` and `leaf_value` defined.
* All other nodes need `split_feature`, `left_child`, `right_child`,
* All other nodes need `split_feature`, `left_child`, `right_child`,
`threshold`, `decision_type`, and `default_left` defined.
--
+
@ -195,41 +195,41 @@ There are two major types of nodes: leaf nodes and not-leaf nodes.
[%collapsible%open]
======
`decision_type`::
(Optional, string)
Indicates the positive value (in other words, when to choose the left node)
(Optional, string)
Indicates the positive value (in other words, when to choose the left node)
decision type. Supported `lt`, `lte`, `gt`, `gte`. Defaults to `lte`.
`default_left`::
(Optional, boolean)
Indicates whether to default to the left when the feature is missing. Defaults
(Optional, boolean)
Indicates whether to default to the left when the feature is missing. Defaults
to `true`.
`leaf_value`::
(Optional, double)
The leaf value of the of the node, if the value is a leaf (in other words, no
(Optional, double)
The leaf value of the node, if the value is a leaf (in other words, no
children).
`left_child`::
(Optional, integer)
(Optional, integer)
The index of the left child.
`node_index`::
(Integer)
(Integer)
The index of the current node.
`right_child`::
(Optional, integer)
(Optional, integer)
The index of the right child.
`split_feature`::
(Optional, integer)
(Optional, integer)
The index of the feature value in the feature array.
`split_gain`::
(Optional, double) The information gain from the split.
`threshold`::
(Optional, double)
(Optional, double)
The decision threshold with which to compare the feature value.
======
//End tree node
@ -244,9 +244,9 @@ The definition for an ensemble model. See <<ml-put-inference-model-example>>.
======
//Begin aggregate output
`aggregate_output`::
(Required, object)
An aggregated output object that defines how to aggregate the outputs of the
`trained_models`. Supported objects are `weighted_mode`, `weighted_sum`, and
(Required, object)
An aggregated output object that defines how to aggregate the outputs of the
`trained_models`. Supported objects are `weighted_mode`, `weighted_sum`,
`logistic_regression`, and `exponent`. See <<ml-put-inference-aggregated-output-example>>.
+
.Properties of `aggregate_output`
@ -254,65 +254,82 @@ An aggregated output object that defines how to aggregate the outputs of the
=======
//Begin logistic regression
`logistic_regression`::
(Optional, object)
This `aggregated_output` type works with binary classification (classification
for values [0, 1]). It multiplies the outputs (in the case of the `ensemble`
model, the inference model values) by the supplied `weights`. The resulting
vector is summed and passed to a
https://en.wikipedia.org/wiki/Sigmoid_function[`sigmoid` function]. The result
of the `sigmoid` function is considered the probability of class 1 (`P_1`),
consequently, the probability of class 0 is `1 - P_1`. The class with the
highest probability (either 0 or 1) is then returned. For more information about
logistic regression, see
(Optional, object)
This `aggregated_output` type works with binary classification (classification
for values [0, 1]). It multiplies the outputs (in the case of the `ensemble`
model, the inference model values) by the supplied `weights`. The resulting
vector is summed and passed to a
https://en.wikipedia.org/wiki/Sigmoid_function[`sigmoid` function]. The result
of the `sigmoid` function is considered the probability of class 1 (`P_1`),
consequently, the probability of class 0 is `1 - P_1`. The class with the
highest probability (either 0 or 1) is then returned. For more information about
logistic regression, see
https://en.wikipedia.org/wiki/Logistic_regression[this wiki article].
+
.Properties of `logistic_regression`
[%collapsible%open]
========
`weights`:::
(Required, double)
The weights to multiply by the input values (the inference values of the trained
(Required, double)
The weights to multiply by the input values (the inference values of the trained
models).
========
//End logistic regression
//Begin weighted sum
`weighted_sum`::
(Optional, object)
This `aggregated_output` type works with regression. The weighted sum of the
(Optional, object)
This `aggregated_output` type works with regression. The weighted sum of the
input values.
+
.Properties of `weighted_sum`
[%collapsible%open]
========
`weights`:::
(Required, double)
The weights to multiply by the input values (the inference values of the trained
(Required, double)
The weights to multiply by the input values (the inference values of the trained
models).
========
//End weighted sum
//Begin weighted mode
`weighted_mode`::
(Optional, object)
This `aggregated_output` type works with regression or classification. It takes
a weighted vote of the input values. The most common input value (taking the
(Optional, object)
This `aggregated_output` type works with regression or classification. It takes
a weighted vote of the input values. The most common input value (taking the
weights into account) is returned.
+
.Properties of `weighted_mode`
[%collapsible%open]
========
`weights`:::
(Required, double)
The weights to multiply by the input values (the inference values of the trained
(Required, double)
The weights to multiply by the input values (the inference values of the trained
models).
========
//End weighted mode
//Begin exponent
`exponent`::
(Optional, object)
This `aggregated_output` type works with regression. It takes a weighted sum of
the input values and passes the result to an exponent function
(`e^x` where `x` is the sum of the weighted values).
+
.Properties of `exponent`
[%collapsible%open]
========
`weights`:::
(Required, double)
The weights to multiply by the input values (the inference values of the trained
models).
========
//End exponent
=======
//End aggregate output
`classification_labels`::
(Optional, string)
(Optional, string)
An array of classification labels.
`feature_names`::
@ -320,12 +337,12 @@ An array of classification labels.
Features expected by the ensemble, in their expected order.
`target_type`::
(Required, string)
(Required, string)
String indicating the model target type; `regression` or `classification`.
`trained_models`::
(Required, object)
An array of `trained_model` objects. Supported trained models are `tree` and
An array of `trained_model` objects. Supported trained models are `tree` and
`ensemble`.
======
//End ensemble
@ -337,7 +354,7 @@ An array of `trained_model` objects. Supported trained models are `tree` and
//End definition
`description`::
(Optional, string)
(Optional, string)
A human-readable description of the {infer} trained model.
//Begin inference_config
@ -398,24 +415,24 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification
//Begin input
`input`::
(Required, object)
(Required, object)
The input field names for the model definition.
+
.Properties of `input`
[%collapsible%open]
====
`field_names`:::
(Required, string)
(Required, string)
An array of input field names for the model.
====
//End input
`metadata`::
(Optional, object)
(Optional, object)
An object map that contains metadata about the model.
`tags`::
(Optional, string)
(Optional, string)
An array of tags to organize the model.
@ -451,10 +468,10 @@ The next example shows a `one_hot_encoding` preprocessor object:
[source,js]
----------------------------------
{
"one_hot_encoding":{
{
"one_hot_encoding":{
"field":"FlightDelayType",
"hot_map":{
"hot_map":{
"Carrier Delay":"FlightDelayType_Carrier Delay",
"NAS Delay":"FlightDelayType_NAS Delay",
"No Delay":"FlightDelayType_No Delay",
@ -521,7 +538,7 @@ The first example shows a `trained_model` object:
"left_child":1,
"right_child":2
},
...
...
{
"node_index":9,
"leaf_value":-27.68987349695448
@ -615,8 +632,21 @@ Example of a `weighted_mode` object:
//NOTCONSOLE
Example of an `exponent` object:
[source,js]
----------------------------------
"aggregate_output" : {
"exponent" : {
"weights" : [1.0, 1.0, 1.0, 1.0, 1.0]
}
}
----------------------------------
//NOTCONSOLE
[[ml-put-inference-json-schema]]
===== {infer-cap} JSON schema
For the full JSON schema of model {infer},
For the full JSON schema of model {infer},
https://github.com/elastic/ml-json-schemas[click here].

View File

@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInfe
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Exponent;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
@ -90,6 +91,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
LogisticRegression.NAME,
LogisticRegression::fromXContentLenient));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
Exponent.NAME,
Exponent::fromXContentLenient));
// Model Strict
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict));
@ -108,6 +112,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
LogisticRegression.NAME,
LogisticRegression::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
Exponent.NAME,
Exponent::fromXContentStrict));
// Inference Configs
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class, ClassificationConfig.NAME,
@ -166,6 +173,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
LogisticRegression.NAME.getPreferredName(),
LogisticRegression::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
Exponent.NAME.getPreferredName(),
Exponent::new));
// Inference Results
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,

View File

@ -0,0 +1,157 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
/**
 * Ensemble output aggregator that computes {@code exp(sum(weight[i] * value[i]))}
 * over the member models' single-valued regression outputs. Only compatible with
 * {@link TargetType#REGRESSION}.
 */
public class Exponent implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {

    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Exponent.class);
    public static final ParseField NAME = new ParseField("exponent");
    public static final ParseField WEIGHTS = new ParseField("weights");

    private static final ConstructingObjectParser<Exponent, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<Exponent, Void> STRICT_PARSER = createParser(false);

    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<Exponent, Void> createParser(boolean lenient) {
        ConstructingObjectParser<Exponent, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new Exponent((List<Double>)a[0]));
        // weights is optional; absent means the inputs are summed unweighted
        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
        return parser;
    }

    public static Exponent fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    public static Exponent fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    /** Per-model weights, or {@code null} for an unweighted aggregation. */
    private final double[] weights;

    Exponent() {
        this((List<Double>) null);
    }

    private Exponent(List<Double> weights) {
        this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
    }

    public Exponent(double[] weights) {
        this.weights = weights;
    }

    public Exponent(StreamInput in) throws IOException {
        // Wire format: boolean presence flag followed by the array when present.
        if (in.readBoolean()) {
            this.weights = in.readDoubleArray();
        } else {
            this.weights = null;
        }
    }

    @Override
    public Integer expectedValueSize() {
        // Unknown when no weights were configured.
        return this.weights == null ? null : this.weights.length;
    }

    /**
     * Multiplies each model's single regression output by its weight (if any).
     *
     * @param values one single-element array per member model
     * @return the (optionally weighted) values, one per model
     * @throws IllegalArgumentException if weights are configured and the lengths differ
     */
    @Override
    public double[] processValues(double[][] values) {
        Objects.requireNonNull(values, "values must not be null");
        if (weights != null && values.length != weights.length) {
            throw new IllegalArgumentException("values must be the same length as weights.");
        }
        // Guard the empty case so an enabled assertion fails as AssertionError,
        // not ArrayIndexOutOfBoundsException, on zero-length input.
        assert values.length == 0 || values[0].length == 1 : "expected single-valued regression outputs";
        double[] processed = new double[values.length];
        for (int i = 0; i < values.length; ++i) {
            processed[i] = weights == null ? values[i][0] : weights[i] * values[i][0];
        }
        return processed;
    }

    /**
     * Sums the processed values and exponentiates the result. Non-finite values
     * are skipped so a single NaN/Infinity does not poison the whole sum.
     */
    @Override
    public double aggregate(double[] values) {
        Objects.requireNonNull(values, "values must not be null");
        double sum = 0.0;
        for (double val : values) {
            if (Double.isFinite(val)) {
                sum += val;
            }
        }
        return Math.exp(sum);
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public boolean compatibleWith(TargetType targetType) {
        return TargetType.REGRESSION.equals(targetType);
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(weights != null);
        if (weights != null) {
            out.writeDoubleArray(weights);
        }
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (weights != null) {
            builder.field(WEIGHTS.getPreferredName(), weights);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Exponent that = (Exponent) o;
        return Arrays.equals(weights, that.weights);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(weights);
    }

    @Override
    public long ramBytesUsed() {
        long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
        return SHALLOW_SIZE + weightSize;
    }
}

View File

@ -80,7 +80,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));
}
OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) :
OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ?
randomFrom(new WeightedSum(weights), new Exponent(weights)) :
randomFrom(
new WeightedMode(
weights,

View File

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.stream.Stream;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo;
/**
 * Serialization, parsing, and behavior tests for the x-pack {@link Exponent} aggregator.
 */
public class ExponentTests extends WeightedAggregatorTests<Exponent> {

    @Override
    Exponent createTestInstance(int numberOfWeights) {
        // Fill the weight array directly instead of going through a stream.
        double[] randomWeights = new double[numberOfWeights];
        for (int i = 0; i < numberOfWeights; i++) {
            randomWeights[i] = randomDouble();
        }
        return new Exponent(randomWeights);
    }

    @Override
    protected Exponent doParseInstance(XContentParser parser) throws IOException {
        if (lenient) {
            return Exponent.fromXContentLenient(parser);
        }
        return Exponent.fromXContentStrict(parser);
    }

    @Override
    protected Exponent createTestInstance() {
        // Cover both the "no weights" and the weighted configuration.
        return randomBoolean() ? new Exponent() : createTestInstance(randomIntBetween(1, 100));
    }

    @Override
    protected Writeable.Reader<Exponent> instanceReader() {
        return Exponent::new;
    }

    public void testAggregate() {
        double[][] modelOutputs = new double[][]{
            new double[] {.01},
            new double[] {.2},
            new double[] {.002},
            new double[] {-.01},
            new double[] {.1}
        };

        // Unit weights: e^(0.01 + 0.2 + 0.002 - 0.01 + 0.1) = e^0.302
        double[] unitWeights = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
        Exponent aggregator = new Exponent(unitWeights);
        assertThat(aggregator.aggregate(aggregator.processValues(modelOutputs)), closeTo(1.35256, 0.00001));

        // Mixed weights, including zeros that drop the last two values.
        double[] mixedWeights = new double[]{.01, -1.0, .1, 0.0, 0.0};
        aggregator = new Exponent(mixedWeights);
        assertThat(aggregator.aggregate(aggregator.processValues(modelOutputs)), closeTo(0.81897, 0.00001));

        // No weights behaves like unit weights.
        aggregator = new Exponent();
        assertThat(aggregator.aggregate(aggregator.processValues(modelOutputs)), closeTo(1.35256, 0.00001));
    }

    public void testCompatibleWith() {
        Exponent aggregator = createTestInstance();
        assertThat(aggregator.compatibleWith(TargetType.CLASSIFICATION), is(false));
        assertThat(aggregator.compatibleWith(TargetType.REGRESSION), is(true));
    }
}