[ML] add exponent output aggregator to inference

* fixing docs

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>

parent 935a49a8d6
commit b9d9964d10
MlInferenceNamedXContentProvider.java (client)

@@ -24,6 +24,7 @@ import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;

@@ -82,6 +83,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
             new ParseField(LogisticRegression.NAME),
             LogisticRegression::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
+            new ParseField(Exponent.NAME),
+            Exponent::fromXContent));

         return namedXContent;
     }
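For orientation, a sketch of how these registry entries are consumed; the provider and registry classes come from the diff above, but the wiring below is illustrative, not part of this commit:

[source,java]
----
// Building a registry from the provider's entries lets XContent parsing
// resolve the field name "exponent" to Exponent::fromXContent.
NamedXContentRegistry registry = new NamedXContentRegistry(
    new MlInferenceNamedXContentProvider().getNamedXContentParsers());
----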
Exponent.java (client, new file)

@@ -0,0 +1,83 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+
+public class Exponent implements OutputAggregator {
+
+    public static final String NAME = "exponent";
+    public static final ParseField WEIGHTS = new ParseField("weights");
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<Exponent, Void> PARSER = new ConstructingObjectParser<>(
+        NAME,
+        true,
+        a -> new Exponent((List<Double>)a[0]));
+    static {
+        PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
+    }
+
+    public static Exponent fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final List<Double> weights;
+
+    public Exponent(List<Double> weights) {
+        this.weights = weights;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (weights != null) {
+            builder.field(WEIGHTS.getPreferredName(), weights);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Exponent that = (Exponent) o;
+        return Objects.equals(weights, that.weights);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(weights);
+    }
+}
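Given the `toXContent` implementation above, serializing an `Exponent` produces a single optional `weights` array; a hedged usage sketch (the demo class and values are made up, not part of this commit):

[source,java]
----
import java.io.IOException;
import java.util.Arrays;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;

class ExponentXContentDemo { // hypothetical demo class
    public static void main(String[] args) throws IOException {
        Exponent exponent = new Exponent(Arrays.asList(0.5, 1.5));
        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
            exponent.toXContent(builder, ToXContent.EMPTY_PARAMS);
            // Expected output: {"weights":[0.5,1.5]}
            System.out.println(Strings.toString(builder));
        }
    }
}
----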
RestHighLevelClientTests.java

@@ -80,6 +80,7 @@ import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
 import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;

@@ -703,7 +704,7 @@ public class RestHighLevelClientTests extends ESTestCase {

     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(68, namedXContents.size());
+        assertEquals(69, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {

@@ -788,9 +789,9 @@ public class RestHighLevelClientTests extends ESTestCase {
         assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
         assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
-        assertEquals(Integer.valueOf(3),
+        assertEquals(Integer.valueOf(4),
             categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
-        assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME));
+        assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME, Exponent.NAME));
         assertEquals(Integer.valueOf(2),
             categories.get(org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig.class));
         assertThat(names, hasItems(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName()));
EnsembleTests.java (client)

@@ -73,7 +73,8 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
             categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));
         }
         List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
-        OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) :
+        OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ?
+            randomFrom(new WeightedSum(weights), new Exponent(weights)) :
             randomFrom(
                 new WeightedMode(
                     categoryLabels != null ? categoryLabels.size() : randomIntBetween(2, 10),
ExponentTests.java (client, new file)

@@ -0,0 +1,51 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class ExponentTests extends AbstractXContentTestCase<Exponent> {
+
+    Exponent createTestInstance(int numberOfWeights) {
+        return new Exponent(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
+    }
+
+    @Override
+    protected Exponent doParseInstance(XContentParser parser) throws IOException {
+        return Exponent.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Exponent createTestInstance() {
+        return randomBoolean() ? new Exponent(null) : createTestInstance(randomIntBetween(1, 100));
+    }
+
+}
put inference API docs (asciidoc)

@@ -27,7 +27,7 @@ experimental[]
 [[ml-put-inference-prereq]]
 ==== {api-prereq-title}

-If the {es} {security-features} are enabled, you must have the following
+If the {es} {security-features} are enabled, you must have the following
 built-in roles and privileges:

 * `machine_learning_admin`

@@ -38,15 +38,15 @@ For more information, see <<security-privileges>> and <<built-in-roles>>.
 [[ml-put-inference-desc]]
 ==== {api-description-title}

-The create {infer} trained model API enables you to supply a trained model that
-is not created by {dfanalytics}.
+The create {infer} trained model API enables you to supply a trained model that
+is not created by {dfanalytics}.


 [[ml-put-inference-path-params]]
 ==== {api-path-parms-title}

 `<model_id>`::
-(Required, string)
+(Required, string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]

 [role="child_attributes"]

@@ -54,14 +54,14 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 ==== {api-request-body-title}

 `compressed_definition`::
-(Required, string)
-The compressed (GZipped and Base64 encoded) {infer} definition of the model.
+(Required, string)
+The compressed (GZipped and Base64 encoded) {infer} definition of the model.
 If `compressed_definition` is specified, then `definition` cannot be specified.

 //Begin definition
 `definition`::
-(Required, object)
-The {infer} definition for the model. If `definition` is specified, then
+(Required, object)
+The {infer} definition for the model. If `definition` is specified, then
 `compressed_definition` cannot be specified.
 +
 .Properties of `definition`
@@ -77,58 +77,58 @@ Collection of preprocessors. See <<ml-put-inference-preprocessor-example>>.
 =====
 //Begin frequency encoding
 `frequency_encoding`::
-(Required, object)
+(Required, object)
 Defines a frequency encoding for a field.
 +
 .Properties of `frequency_encoding`
 [%collapsible%open]
 ======
 `feature_name`::
-(Required, string)
+(Required, string)
 The name of the resulting feature.

 `field`::
-(Required, string)
+(Required, string)
 The field name to encode.

 `frequency_map`::
-(Required, object map of string:double)
+(Required, object map of string:double)
 Object that maps the field value to the frequency encoded value.
 ======
 //End frequency encoding

 //Begin one hot encoding
 `one_hot_encoding`::
-(Required, object)
+(Required, object)
 Defines a one hot encoding map for a field.
 +
 .Properties of `one_hot_encoding`
 [%collapsible%open]
 ======
 `field`::
-(Required, string)
+(Required, string)
 The field name to encode.

 `hot_map`::
-(Required, object map of strings)
+(Required, object map of strings)
 String map of "field_value: one_hot_column_name".
 ======
 //End one hot encoding

 //Begin target mean encoding
 `target_mean_encoding`::
-(Required, object)
+(Required, object)
 Defines a target mean encoding for a field.
 +
 .Properties of `target_mean_encoding`
 [%collapsible%open]
 ======
 `default_value`:::
-(Required, double)
+(Required, double)
 The feature value if the field value is not in the `target_map`.

 `feature_name`:::
-(Required, string)
+(Required, string)
 The name of the resulting feature.

 `field`:::
@@ -136,7 +136,7 @@ The name of the resulting feature.
 The field name to encode.

 `target_map`:::
-(Required, object map of string:double)
+(Required, object map of string:double)
 Object that maps the field value to the target mean value.
 ======
 //End target mean encoding

@@ -145,7 +145,7 @@ Object that maps the field value to the target mean value.

 //Begin trained model
 `trained_model`::
-(Required, object)
+(Required, object)
 The definition of the trained model.
 +
 .Properties of `trained_model`

@@ -153,14 +153,14 @@ The definition of the trained model.
 =====
 //Begin tree
 `tree`::
-(Required, object)
+(Required, object)
 The definition for a binary decision tree.
 +
 .Properties of `tree`
 [%collapsible%open]
 ======
 `classification_labels`:::
-(Optional, string) An array of classification labels (used for
+(Optional, string) An array of classification labels (used for
 `classification`).

 `feature_names`:::
@@ -168,26 +168,26 @@ The definition for a binary decision tree.
 Features expected by the tree, in their expected order.

 `target_type`:::
-(Required, string)
+(Required, string)
 String indicating the model target type; `regression` or `classification`.

 `tree_structure`:::
-(Required, object)
-An array of `tree_node` objects. The nodes must be in ordinal order by their
+(Required, object)
+An array of `tree_node` objects. The nodes must be in ordinal order by their
 `tree_node.node_index` value.
 ======
 //End tree

 //Begin tree node
 `tree_node`::
-(Required, object)
+(Required, object)
 The definition of a node in a tree.
 +
 --
 There are two major types of nodes: leaf nodes and not-leaf nodes.

 * Leaf nodes only need `node_index` and `leaf_value` defined.
-* All other nodes need `split_feature`, `left_child`, `right_child`,
+* All other nodes need `split_feature`, `left_child`, `right_child`,
 `threshold`, `decision_type`, and `default_left` defined.
 --
 +
@@ -195,41 +195,41 @@ There are two major types of nodes: leaf nodes and not-leaf nodes.
 [%collapsible%open]
 ======
 `decision_type`::
-(Optional, string)
-Indicates the positive value (in other words, when to choose the left node)
+(Optional, string)
+Indicates the positive value (in other words, when to choose the left node)
 decision type. Supported `lt`, `lte`, `gt`, `gte`. Defaults to `lte`.

 `default_left`::
-(Optional, boolean)
-Indicates whether to default to the left when the feature is missing. Defaults
+(Optional, boolean)
+Indicates whether to default to the left when the feature is missing. Defaults
 to `true`.

 `leaf_value`::
-(Optional, double)
-The leaf value of the of the node, if the value is a leaf (in other words, no
+(Optional, double)
+The leaf value of the of the node, if the value is a leaf (in other words, no
 children).

 `left_child`::
-(Optional, integer)
+(Optional, integer)
 The index of the left child.

 `node_index`::
-(Integer)
+(Integer)
 The index of the current node.

 `right_child`::
-(Optional, integer)
+(Optional, integer)
 The index of the right child.

 `split_feature`::
-(Optional, integer)
+(Optional, integer)
 The index of the feature value in the feature array.

 `split_gain`::
 (Optional, double) The information gain from the split.

 `threshold`::
-(Optional, double)
+(Optional, double)
 The decision threshold with which to compare the feature value.
 ======
 //End tree node
@@ -244,9 +244,9 @@ The definition for an ensemble model. See <<ml-put-inference-model-example>>.
 ======
 //Begin aggregate output
 `aggregate_output`::
-(Required, object)
-An aggregated output object that defines how to aggregate the outputs of the
-`trained_models`. Supported objects are `weighted_mode`, `weighted_sum`, and
+(Required, object)
+An aggregated output object that defines how to aggregate the outputs of the
+`trained_models`. Supported objects are `weighted_mode`, `weighted_sum`, and
 `logistic_regression`. See <<ml-put-inference-aggregated-output-example>>.
 +
 .Properties of `aggregate_output`
@@ -254,65 +254,82 @@ An aggregated output object that defines how to aggregate the outputs of the
 =======
 //Begin logistic regression
 `logistic_regression`::
-(Optional, object)
-This `aggregated_output` type works with binary classification (classification
-for values [0, 1]). It multiplies the outputs (in the case of the `ensemble`
-model, the inference model values) by the supplied `weights`. The resulting
-vector is summed and passed to a
-https://en.wikipedia.org/wiki/Sigmoid_function[`sigmoid` function]. The result
-of the `sigmoid` function is considered the probability of class 1 (`P_1`),
-consequently, the probability of class 0 is `1 - P_1`. The class with the
-highest probability (either 0 or 1) is then returned. For more information about
-logistic regression, see
+(Optional, object)
+This `aggregated_output` type works with binary classification (classification
+for values [0, 1]). It multiplies the outputs (in the case of the `ensemble`
+model, the inference model values) by the supplied `weights`. The resulting
+vector is summed and passed to a
+https://en.wikipedia.org/wiki/Sigmoid_function[`sigmoid` function]. The result
+of the `sigmoid` function is considered the probability of class 1 (`P_1`),
+consequently, the probability of class 0 is `1 - P_1`. The class with the
+highest probability (either 0 or 1) is then returned. For more information about
+logistic regression, see
 https://en.wikipedia.org/wiki/Logistic_regression[this wiki article].
 +
 .Properties of `logistic_regression`
 [%collapsible%open]
 ========
 `weights`:::
-(Required, double)
-The weights to multiply by the input values (the inference values of the trained
+(Required, double)
+The weights to multiply by the input values (the inference values of the trained
 models).
 ========
 //End logistic regression

 //Begin weighted sum
 `weighted_sum`::
-(Optional, object)
-This `aggregated_output` type works with regression. The weighted sum of the
+(Optional, object)
+This `aggregated_output` type works with regression. The weighted sum of the
 input values.
 +
 .Properties of `weighted_sum`
 [%collapsible%open]
 ========
 `weights`:::
-(Required, double)
-The weights to multiply by the input values (the inference values of the trained
+(Required, double)
+The weights to multiply by the input values (the inference values of the trained
 models).
 ========
 //End weighted sum

 //Begin weighted mode
 `weighted_mode`::
-(Optional, object)
-This `aggregated_output` type works with regression or classification. It takes
-a weighted vote of the input values. The most common input value (taking the
+(Optional, object)
+This `aggregated_output` type works with regression or classification. It takes
+a weighted vote of the input values. The most common input value (taking the
 weights into account) is returned.
 +
 .Properties of `weighted_mode`
 [%collapsible%open]
 ========
 `weights`:::
-(Required, double)
-The weights to multiply by the input values (the inference values of the trained
+(Required, double)
+The weights to multiply by the input values (the inference values of the trained
 models).
 ========
 //End weighted mode

+//Begin exponent
+`exponent`::
+(Optional, object)
+This `aggregated_output` type works with regression. It takes a weighted sum of
+the input values and passes the result to an exponent function
+(`e^x` where `x` is the sum of the weighted values).
++
+.Properties of `exponent`
+[%collapsible%open]
+========
+`weights`:::
+(Required, double)
+The weights to multiply by the input values (the inference values of the trained
+models).
+========
+//End exponent
+
 =======
 //End aggregate output

 `classification_labels`::
-(Optional, string)
+(Optional, string)
 An array of classification labels.

 `feature_names`::
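The arithmetic behind these aggregators is easiest to see with concrete numbers; the sketch below is purely illustrative (made-up values and weights), not part of this commit:

[source,java]
----
// Two ensemble member outputs aggregated with unit weights.
double[] weights = {1.0, 1.0};
double[] outputs = {0.2, 0.1};
double x = weights[0] * outputs[0] + weights[1] * outputs[1]; // weighted sum = 0.3

double weightedSum = x;                  // weighted_sum prediction: 0.3
double exp = Math.exp(x);                // exponent prediction: e^0.3 ≈ 1.3499
double p1 = 1.0 / (1.0 + Math.exp(-x));  // logistic_regression: sigmoid(0.3) ≈ 0.5744
int winningClass = p1 > 0.5 ? 1 : 0;     // class 1, since P_1 > 1 - P_1
----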
@@ -320,12 +337,12 @@ An array of classification labels.
 Features expected by the ensemble, in their expected order.

 `target_type`::
-(Required, string)
+(Required, string)
 String indicating the model target type; `regression` or `classification.`

 `trained_models`::
 (Required, object)
-An array of `trained_model` objects. Supported trained models are `tree` and
+An array of `trained_model` objects. Supported trained models are `tree` and
 `ensemble`.
 ======
 //End ensemble

@@ -337,7 +354,7 @@ An array of `trained_model` objects. Supported trained models are `tree` and
 //End definition

 `description`::
-(Optional, string)
+(Optional, string)
 A human-readable description of the {infer} trained model.

 //Begin inference_config

@@ -398,24 +415,24 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification

 //Begin input
 `input`::
-(Required, object)
+(Required, object)
 The input field names for the model definition.
 +
 .Properties of `input`
 [%collapsible%open]
 ====
 `field_names`:::
-(Required, string)
+(Required, string)
 An array of input field names for the model.
 ====
 //End input

 `metadata`::
-(Optional, object)
+(Optional, object)
 An object map that contains metadata about the model.

 `tags`::
-(Optional, string)
+(Optional, string)
 An array of tags to organize the model.


@@ -451,10 +468,10 @@ The next example shows a `one_hot_encoding` preprocessor object:

 [source,js]
 ----------------------------------
-{
-  "one_hot_encoding":{
+{
+  "one_hot_encoding":{
     "field":"FlightDelayType",
-    "hot_map":{
+    "hot_map":{
       "Carrier Delay":"FlightDelayType_Carrier Delay",
       "NAS Delay":"FlightDelayType_NAS Delay",
       "No Delay":"FlightDelayType_No Delay",

@@ -521,7 +538,7 @@ The first example shows a `trained_model` object:
           "left_child":1,
           "right_child":2
         },
-        ...
+        ...
         {
           "node_index":9,
           "leaf_value":-27.68987349695448
@@ -615,8 +632,21 @@ Example of a `weighted_mode` object:
 //NOTCONSOLE


+Example of an `exponent` object:
+
+[source,js]
+----------------------------------
+"aggregate_output" : {
+   "exponent" : {
+      "weights" : [1.0, 1.0, 1.0, 1.0, 1.0]
+   }
+}
+----------------------------------
+//NOTCONSOLE
+
+
 [[ml-put-inference-json-schema]]
 ===== {infer-cap} JSON schema

-For the full JSON schema of model {infer},
+For the full JSON schema of model {infer},
 https://github.com/elastic/ml-json-schemas[click here].
MlInferenceNamedXContentProvider.java (x-pack core)

@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInfe
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Exponent;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;

@@ -90,6 +91,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
             LogisticRegression.NAME,
             LogisticRegression::fromXContentLenient));
+        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
+            Exponent.NAME,
+            Exponent::fromXContentLenient));

         // Model Strict
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict));

@@ -108,6 +112,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
             LogisticRegression.NAME,
             LogisticRegression::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
+            Exponent.NAME,
+            Exponent::fromXContentStrict));

         // Inference Configs
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class, ClassificationConfig.NAME,

@@ -166,6 +173,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
             LogisticRegression.NAME.getPreferredName(),
             LogisticRegression::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
+            Exponent.NAME.getPreferredName(),
+            Exponent::new));

         // Inference Results
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
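Both lenient and strict variants are registered above. A rough sketch of the difference (the parser construction is illustrative only; the exact `createParser` signature varies between versions):

[source,java]
----
// Lenient parsing ignores unknown fields, so configs written by newer
// versions can still be read; strict parsing rejects them.
String json = "{\"weights\":[1.0],\"some_future_field\":true}";
try (XContentParser parser = JsonXContent.jsonXContent.createParser(
        NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json)) {
    Exponent parsed = Exponent.fromXContentLenient(parser); // unknown field ignored
}
// Exponent.fromXContentStrict(...) on the same payload would throw.
----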
Exponent.java (x-pack core, new file)

@@ -0,0 +1,157 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+import org.apache.lucene.util.RamUsageEstimator;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+public class Exponent implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
+
+    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Exponent.class);
+    public static final ParseField NAME = new ParseField("exponent");
+    public static final ParseField WEIGHTS = new ParseField("weights");
+
+    private static final ConstructingObjectParser<Exponent, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<Exponent, Void> STRICT_PARSER = createParser(false);
+
+    @SuppressWarnings("unchecked")
+    private static ConstructingObjectParser<Exponent, Void> createParser(boolean lenient) {
+        ConstructingObjectParser<Exponent, Void> parser = new ConstructingObjectParser<>(
+            NAME.getPreferredName(),
+            lenient,
+            a -> new Exponent((List<Double>)a[0]));
+        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
+        return parser;
+    }
+
+    public static Exponent fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null);
+    }
+
+    public static Exponent fromXContentLenient(XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null);
+    }
+
+    private final double[] weights;
+
+    Exponent() {
+        this((List<Double>) null);
+    }
+
+    private Exponent(List<Double> weights) {
+        this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
+    }
+
+    public Exponent(double[] weights) {
+        this.weights = weights;
+    }
+
+    public Exponent(StreamInput in) throws IOException {
+        if (in.readBoolean()) {
+            this.weights = in.readDoubleArray();
+        } else {
+            this.weights = null;
+        }
+    }
+
+    @Override
+    public Integer expectedValueSize() {
+        return this.weights == null ? null : this.weights.length;
+    }
+
+    @Override
+    public double[] processValues(double[][] values) {
+        Objects.requireNonNull(values, "values must not be null");
+        if (weights != null && values.length != weights.length) {
+            throw new IllegalArgumentException("values must be the same length as weights.");
+        }
+        assert values[0].length == 1;
+        double[] processed = new double[values.length];
+        for (int i = 0; i < values.length; ++i) {
+            if (weights != null) {
+                processed[i] = weights[i] * values[i][0];
+            } else {
+                processed[i] = values[i][0];
+            }
+        }
+        return processed;
+    }
+
+    @Override
+    public double aggregate(double[] values) {
+        Objects.requireNonNull(values, "values must not be null");
+        double sum = 0.0;
+        for (double val : values) {
+            if (Double.isFinite(val)) {
+                sum += val;
+            }
+        }
+        return Math.exp(sum);
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean compatibleWith(TargetType targetType) {
+        return TargetType.REGRESSION.equals(targetType);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeBoolean(weights != null);
+        if (weights != null) {
+            out.writeDoubleArray(weights);
+        }
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (weights != null) {
+            builder.field(WEIGHTS.getPreferredName(), weights);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Exponent that = (Exponent) o;
+        return Arrays.equals(weights, that.weights);
+    }
+
+    @Override
+    public int hashCode() {
+        return Arrays.hashCode(weights);
+    }
+
+    @Override
+    public long ramBytesUsed() {
+        long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
+        return SHALLOW_SIZE + weightSize;
+    }
+}
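To make the two-step `processValues`/`aggregate` flow concrete, an illustrative walk-through (not part of the commit) using the same numbers as the `ExponentTests` added below:

[source,java]
----
// Each inner array holds one ensemble member's single output value.
double[][] values = {{0.01}, {0.2}, {0.002}, {-0.01}, {0.1}};
Exponent exponent = new Exponent(new double[]{1.0, 1.0, 1.0, 1.0, 1.0});

// processValues multiplies each output by its weight.
double[] weighted = exponent.processValues(values); // {0.01, 0.2, 0.002, -0.01, 0.1}

// aggregate sums the finite values and exponentiates the total.
double prediction = exponent.aggregate(weighted);   // e^0.302 ≈ 1.35256
----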
EnsembleTests.java (x-pack core)

@@ -80,7 +80,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));
         }

-        OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) :
+        OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ?
+            randomFrom(new WeightedSum(weights), new Exponent(weights)) :
             randomFrom(
                 new WeightedMode(
                     weights,
ExponentTests.java (x-pack core, new file)

@@ -0,0 +1,69 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+
+import java.io.IOException;
+import java.util.stream.Stream;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.closeTo;
+
+public class ExponentTests extends WeightedAggregatorTests<Exponent> {
+
+    @Override
+    Exponent createTestInstance(int numberOfWeights) {
+        double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray();
+        return new Exponent(weights);
+    }
+
+    @Override
+    protected Exponent doParseInstance(XContentParser parser) throws IOException {
+        return lenient ? Exponent.fromXContentLenient(parser) : Exponent.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected Exponent createTestInstance() {
+        return randomBoolean() ? new Exponent() : createTestInstance(randomIntBetween(1, 100));
+    }
+
+    @Override
+    protected Writeable.Reader<Exponent> instanceReader() {
+        return Exponent::new;
+    }
+
+    public void testAggregate() {
+        double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
+        double[][] values = new double[][]{
+            new double[] {.01},
+            new double[] {.2},
+            new double[] {.002},
+            new double[] {-.01},
+            new double[] {.1}
+        };
+
+        Exponent exponent = new Exponent(ones);
+        assertThat(exponent.aggregate(exponent.processValues(values)), closeTo(1.35256, 0.00001));
+
+        double[] variedWeights = new double[]{.01, -1.0, .1, 0.0, 0.0};
+
+        exponent = new Exponent(variedWeights);
+        assertThat(exponent.aggregate(exponent.processValues(values)), closeTo(0.81897, 0.00001));
+
+        exponent = new Exponent();
+        assertThat(exponent.aggregate(exponent.processValues(values)), closeTo(1.35256, 0.00001));
+    }
+
+    public void testCompatibleWith() {
+        Exponent exponent = createTestInstance();
+        assertThat(exponent.compatibleWith(TargetType.CLASSIFICATION), is(false));
+        assertThat(exponent.compatibleWith(TargetType.REGRESSION), is(true));
+    }
+}