mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-22 12:56:53 +00:00
Adds InferencePipelineAggregationBuilder to the HLRC duplicating the server side classes
This commit is contained in:
parent
d56fc72ee5
commit
c5443f78ce
@ -54,6 +54,8 @@ import org.elasticsearch.action.search.SearchScrollRequest;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
||||
import org.elasticsearch.action.update.UpdateRequest;
|
||||
import org.elasticsearch.action.update.UpdateResponse;
|
||||
import org.elasticsearch.client.analytics.InferencePipelineAggregationBuilder;
|
||||
import org.elasticsearch.client.analytics.ParsedInference;
|
||||
import org.elasticsearch.client.analytics.ParsedStringStats;
|
||||
import org.elasticsearch.client.analytics.ParsedTopMetrics;
|
||||
import org.elasticsearch.client.analytics.StringStatsAggregationBuilder;
|
||||
@ -1957,6 +1959,7 @@ public class RestHighLevelClient implements Closeable {
|
||||
map.put(CompositeAggregationBuilder.NAME, (p, c) -> ParsedComposite.fromXContent(p, (String) c));
|
||||
map.put(StringStatsAggregationBuilder.NAME, (p, c) -> ParsedStringStats.PARSER.parse(p, (String) c));
|
||||
map.put(TopMetricsAggregationBuilder.NAME, (p, c) -> ParsedTopMetrics.PARSER.parse(p, (String) c));
|
||||
map.put(InferencePipelineAggregationBuilder.NAME, (p, c) -> ParsedInference.fromXContent(p, (String ) (c)));
|
||||
List<NamedXContentRegistry.Entry> entries = map.entrySet().stream()
|
||||
.map(entry -> new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(entry.getKey()), entry.getValue()))
|
||||
.collect(Collectors.toList());
|
||||
|
@ -0,0 +1,141 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.client.analytics;
|
||||
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.TreeMap;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* For building inference pipeline aggregations
|
||||
*
|
||||
* NOTE: This extends {@linkplain AbstractPipelineAggregationBuilder} for compatibility
|
||||
* with {@link SearchSourceBuilder#aggregation(PipelineAggregationBuilder)} but it
|
||||
* doesn't support any "server" side things like {@linkplain #doWriteTo(StreamOutput)}
|
||||
* or {@linkplain #createInternal(Map)}
|
||||
*/
|
||||
public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {
|
||||
|
||||
public static String NAME = "inference";
|
||||
|
||||
public static final ParseField MODEL_ID = new ParseField("model_id");
|
||||
private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
|
||||
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, String> PARSER = new ConstructingObjectParser<>(
|
||||
NAME, false,
|
||||
(args, name) -> new InferencePipelineAggregationBuilder(name, (String)args[0], (Map<String, String>) args[1])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), MODEL_ID);
|
||||
PARSER.declareObject(constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD);
|
||||
PARSER.declareNamedObject(InferencePipelineAggregationBuilder::setInferenceConfig,
|
||||
(p, c, n) -> p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG);
|
||||
}
|
||||
|
||||
private final Map<String, String> bucketPathMap;
|
||||
private final String modelId;
|
||||
private InferenceConfig inferenceConfig;
|
||||
|
||||
public static InferencePipelineAggregationBuilder parse(String pipelineAggregatorName,
|
||||
XContentParser parser) {
|
||||
return PARSER.apply(parser, pipelineAggregatorName);
|
||||
}
|
||||
|
||||
public InferencePipelineAggregationBuilder(String name, String modelId, Map<String, String> bucketsPath) {
|
||||
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
|
||||
this.modelId = modelId;
|
||||
this.bucketPathMap = bucketsPath;
|
||||
}
|
||||
|
||||
public void setInferenceConfig(InferenceConfig inferenceConfig) {
|
||||
this.inferenceConfig = inferenceConfig;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void validate(ValidationContext context) {
|
||||
// validation occurs on the server
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doWriteTo(StreamOutput out) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected PipelineAggregator createInternal(Map<String, Object> metaData) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean overrideBucketsPath() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(MODEL_ID.getPreferredName(), modelId);
|
||||
builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketPathMap);
|
||||
if (inferenceConfig != null) {
|
||||
builder.startObject(INFERENCE_CONFIG.getPreferredName());
|
||||
builder.field(inferenceConfig.getName(), inferenceConfig);
|
||||
builder.endObject();
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(super.hashCode(), bucketPathMap, modelId, inferenceConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) return true;
|
||||
if (obj == null || getClass() != obj.getClass()) return false;
|
||||
if (super.equals(obj) == false) return false;
|
||||
|
||||
InferencePipelineAggregationBuilder other = (InferencePipelineAggregationBuilder) obj;
|
||||
return Objects.equals(bucketPathMap, other.bucketPathMap)
|
||||
&& Objects.equals(modelId, other.modelId)
|
||||
&& Objects.equals(inferenceConfig, other.inferenceConfig);
|
||||
}
|
||||
}
|
@ -0,0 +1,137 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.client.analytics;
|
||||
|
||||
import org.elasticsearch.client.ml.inference.results.FeatureImportance;
|
||||
import org.elasticsearch.client.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.ParsedAggregation;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* This class parses the superset of all possible fields that may be written by
|
||||
* InferenceResults. The warning field is mutually exclusive with all the other fields.
|
||||
*
|
||||
* In the case of classification results {@link #getValue()} may return a String,
|
||||
* Boolean or a Double. For regression results {@link #getValue()} is always
|
||||
* a Double.
|
||||
*/
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class ParsedInference extends ParsedAggregation {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
|
||||
new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
|
||||
args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
|
||||
(List<TopClassEntry>) args[2], (String) args[3]));
|
||||
|
||||
public static final ParseField FEATURE_IMPORTANCE = new ParseField("feature_importance");
|
||||
public static final ParseField WARNING = new ParseField("warning");
|
||||
public static final ParseField TOP_CLASSES = new ParseField("top_classes");
|
||||
|
||||
static {
|
||||
PARSER.declareField(optionalConstructorArg(), (p, n) -> {
|
||||
Object o;
|
||||
XContentParser.Token token = p.currentToken();
|
||||
if (token == XContentParser.Token.VALUE_STRING) {
|
||||
o = p.text();
|
||||
} else if (token == XContentParser.Token.VALUE_BOOLEAN) {
|
||||
o = p.booleanValue();
|
||||
} else if (token == XContentParser.Token.VALUE_NUMBER) {
|
||||
o = p.doubleValue();
|
||||
} else {
|
||||
throw new XContentParseException(p.getTokenLocation(),
|
||||
"[" + ParsedInference.class.getSimpleName() + "] failed to parse field [" + CommonFields.VALUE + "] "
|
||||
+ "value [" + token + "] is not a string, boolean or number");
|
||||
}
|
||||
return o;
|
||||
}, CommonFields.VALUE, ObjectParser.ValueType.VALUE);
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p), FEATURE_IMPORTANCE);
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p), TOP_CLASSES);
|
||||
PARSER.declareString(optionalConstructorArg(), WARNING);
|
||||
declareAggregationFields(PARSER);
|
||||
}
|
||||
|
||||
public static ParsedInference fromXContent(XContentParser parser, final String name) {
|
||||
ParsedInference parsed = PARSER.apply(parser, null);
|
||||
parsed.setName(name);
|
||||
return parsed;
|
||||
}
|
||||
|
||||
private final Object value;
|
||||
private final List<FeatureImportance> featureImportance;
|
||||
private final List<TopClassEntry> topClasses;
|
||||
private final String warning;
|
||||
|
||||
ParsedInference(Object value,
|
||||
List<FeatureImportance> featureImportance,
|
||||
List<TopClassEntry> topClasses,
|
||||
String warning) {
|
||||
this.value = value;
|
||||
this.warning = warning;
|
||||
this.featureImportance = featureImportance;
|
||||
this.topClasses = topClasses;
|
||||
}
|
||||
|
||||
public Object getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
public List<FeatureImportance> getFeatureImportance() {
|
||||
return featureImportance;
|
||||
}
|
||||
|
||||
public List<TopClassEntry> getTopClasses() {
|
||||
return topClasses;
|
||||
}
|
||||
|
||||
public String getWarning() {
|
||||
return warning;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
|
||||
if (warning != null) {
|
||||
builder.field(WARNING.getPreferredName(), warning);
|
||||
} else {
|
||||
builder.field(CommonFields.VALUE.getPreferredName(), value);
|
||||
if (topClasses != null && topClasses.size() > 0) {
|
||||
builder.field(TOP_CLASSES.getPreferredName(), topClasses);
|
||||
}
|
||||
if (featureImportance != null && featureImportance.size() > 0) {
|
||||
builder.field(FEATURE_IMPORTANCE.getPreferredName(), featureImportance);
|
||||
}
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getType() {
|
||||
return InferencePipelineAggregationBuilder.NAME;
|
||||
}
|
||||
}
|
@ -0,0 +1,112 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.client.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class FeatureImportance implements ToXContentObject {
|
||||
|
||||
public static final String IMPORTANCE = "importance";
|
||||
public static final String FEATURE_NAME = "feature_name";
|
||||
public static final String CLASS_IMPORTANCE = "class_importance";
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
|
||||
new ConstructingObjectParser<>("feature_importance", true,
|
||||
a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
|
||||
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
|
||||
new ParseField(FeatureImportance.CLASS_IMPORTANCE));
|
||||
}
|
||||
|
||||
public static FeatureImportance fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final Map<String, Double> classImportance;
|
||||
private final double importance;
|
||||
private final String featureName;
|
||||
|
||||
public FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
|
||||
this.featureName = Objects.requireNonNull(featureName);
|
||||
this.importance = importance;
|
||||
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
|
||||
}
|
||||
|
||||
public Map<String, Double> getClassImportance() {
|
||||
return classImportance;
|
||||
}
|
||||
|
||||
public double getImportance() {
|
||||
return importance;
|
||||
}
|
||||
|
||||
public String getFeatureName() {
|
||||
return featureName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FEATURE_NAME, featureName);
|
||||
builder.field(IMPORTANCE, importance);
|
||||
if (classImportance != null && classImportance.isEmpty() == false) {
|
||||
builder.startObject(CLASS_IMPORTANCE);
|
||||
for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
|
||||
builder.field(entry.getKey(), entry.getValue());
|
||||
}
|
||||
builder.endObject();
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
FeatureImportance that = (FeatureImportance) object;
|
||||
return Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(importance, that.importance)
|
||||
&& Objects.equals(classImportance, that.classImportance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(featureName, importance, classImportance);
|
||||
}
|
||||
}
|
@ -0,0 +1,116 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.client.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
public class TopClassEntry implements ToXContentObject {
|
||||
|
||||
public static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
public static final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
|
||||
public static final ParseField CLASS_SCORE = new ParseField("class_score");
|
||||
|
||||
public static final String NAME = "top_class";
|
||||
|
||||
private static final ConstructingObjectParser<TopClassEntry, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME, true, a -> new TopClassEntry(a[0], (Double) a[1], (Double) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareField(constructorArg(), (p, n) -> {
|
||||
Object o;
|
||||
XContentParser.Token token = p.currentToken();
|
||||
if (token == XContentParser.Token.VALUE_STRING) {
|
||||
o = p.text();
|
||||
} else if (token == XContentParser.Token.VALUE_BOOLEAN) {
|
||||
o = p.booleanValue();
|
||||
} else if (token == XContentParser.Token.VALUE_NUMBER) {
|
||||
o = p.doubleValue();
|
||||
} else {
|
||||
throw new XContentParseException(p.getTokenLocation(),
|
||||
"[" + NAME + "] failed to parse field [" + CLASS_NAME + "] value [" + token
|
||||
+ "] is not a string, boolean or number");
|
||||
}
|
||||
return o;
|
||||
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
|
||||
PARSER.declareDouble(constructorArg(), CLASS_PROBABILITY);
|
||||
PARSER.declareDouble(constructorArg(), CLASS_SCORE);
|
||||
}
|
||||
|
||||
public static TopClassEntry fromXContent(XContentParser parser) throws IOException {
|
||||
return PARSER.parse(parser, null);
|
||||
}
|
||||
|
||||
private final Object classification;
|
||||
private final double probability;
|
||||
private final double score;
|
||||
|
||||
public TopClassEntry(Object classification, double probability, double score) {
|
||||
this.classification = Objects.requireNonNull(classification);
|
||||
this.probability = probability;
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
public Object getClassification() {
|
||||
return classification;
|
||||
}
|
||||
|
||||
public double getProbability() {
|
||||
return probability;
|
||||
}
|
||||
|
||||
public double getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASS_NAME.getPreferredName(), classification);
|
||||
builder.field(CLASS_PROBABILITY.getPreferredName(), probability);
|
||||
builder.field(CLASS_SCORE.getPreferredName(), score);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
TopClassEntry that = (TopClassEntry) object;
|
||||
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(classification, probability, score);
|
||||
}
|
||||
}
|
@ -688,6 +688,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
// Explicitly check for metrics from the analytics module because they aren't in InternalAggregationTestCase
|
||||
assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("string_stats")));
|
||||
assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("top_metrics")));
|
||||
assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("inference")));
|
||||
|
||||
assertEquals(expectedInternalAggregations + expectedSuggestions, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
|
@ -0,0 +1,127 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.client.analytics;
|
||||
|
||||
import org.elasticsearch.action.bulk.BulkRequest;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.client.ESRestHighLevelClientTestCase;
|
||||
import org.elasticsearch.client.RequestOptions;
|
||||
import org.elasticsearch.client.indices.CreateIndexRequest;
|
||||
import org.elasticsearch.client.ml.PutTrainedModelRequest;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.ParsedTerms;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class InferenceAggIT extends ESRestHighLevelClientTestCase {
|
||||
|
||||
public void testInferenceAgg() throws IOException {
|
||||
|
||||
// create a very simple decision tree with a root node and 2 leaves
|
||||
List<String> featureNames = Collections.singletonList("cost");
|
||||
Tree.Builder builder = Tree.builder();
|
||||
builder.setFeatureNames(featureNames);
|
||||
TreeNode.Builder root = builder.addJunction(0, 0, true, 1.0);
|
||||
int leftChild = root.getLeftChild();
|
||||
int rightChild = root.getRightChild();
|
||||
builder.addLeaf(leftChild, 10.0);
|
||||
builder.addLeaf(rightChild, 20.0);
|
||||
|
||||
final String modelId = "simple_regression";
|
||||
putTrainedModel(modelId, featureNames, builder.build());
|
||||
|
||||
final String index = "inference-test-data";
|
||||
indexData(index);
|
||||
|
||||
TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("fruit_type").field("fruit");
|
||||
AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_cost").field("cost");
|
||||
termsAgg.subAggregation(avgAgg);
|
||||
|
||||
Map<String, String> bucketPaths = new HashMap<>();
|
||||
bucketPaths.put("cost", "avg_cost");
|
||||
InferencePipelineAggregationBuilder inferenceAgg = new InferencePipelineAggregationBuilder("infer", modelId, bucketPaths);
|
||||
termsAgg.subAggregation(inferenceAgg);
|
||||
|
||||
SearchRequest search = new SearchRequest(index);
|
||||
search.source().aggregation(termsAgg);
|
||||
SearchResponse response = highLevelClient().search(search, RequestOptions.DEFAULT);
|
||||
ParsedTerms terms = response.getAggregations().get("fruit_type");
|
||||
List<? extends Terms.Bucket> buckets = terms.getBuckets();
|
||||
{
|
||||
assertThat(buckets.get(0).getKey(), equalTo("apple"));
|
||||
ParsedInference inference = buckets.get(0).getAggregations().get("infer");
|
||||
assertThat((Double) inference.getValue(), closeTo(20.0, 0.01));
|
||||
assertNull(inference.getWarning());
|
||||
assertNull(inference.getFeatureImportance());
|
||||
assertNull(inference.getTopClasses());
|
||||
}
|
||||
{
|
||||
assertThat(buckets.get(1).getKey(), equalTo("banana"));
|
||||
ParsedInference inference = buckets.get(1).getAggregations().get("infer");
|
||||
assertThat((Double) inference.getValue(), closeTo(10.0, 0.01));
|
||||
assertNull(inference.getWarning());
|
||||
assertNull(inference.getFeatureImportance());
|
||||
assertNull(inference.getTopClasses());
|
||||
}
|
||||
}
|
||||
|
||||
private void putTrainedModel(String modelId, List<String> inputFields, Tree tree) throws IOException {
|
||||
TrainedModelDefinition definition = new TrainedModelDefinition.Builder().setTrainedModel(tree).build();
|
||||
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
|
||||
.setDefinition(definition)
|
||||
.setModelId(modelId)
|
||||
.setInferenceConfig(new RegressionConfig())
|
||||
.setInput(new TrainedModelInput(inputFields))
|
||||
.setDescription("test model")
|
||||
.build();
|
||||
highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
|
||||
}
|
||||
|
||||
private void indexData(String index) throws IOException {
|
||||
CreateIndexRequest create = new CreateIndexRequest(index);
|
||||
create.mapping("{\"properties\": {\"fruit\": {\"type\": \"keyword\"}," +
|
||||
"\"cost\": {\"type\": \"double\"}}}", XContentType.JSON);
|
||||
highLevelClient().indices().create(create, RequestOptions.DEFAULT);
|
||||
BulkRequest bulk = new BulkRequest(index).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
bulk.add(new IndexRequest().source(XContentType.JSON, "fruit", "apple", "cost", "1.2"));
|
||||
bulk.add(new IndexRequest().source(XContentType.JSON, "fruit", "banana", "cost", "0.8"));
|
||||
bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
|
||||
}
|
||||
}
|
@ -54,5 +54,4 @@ public class TrainedModelInputTests extends AbstractXContentTestCase<TrainedMode
|
||||
protected TrainedModelInput createTestInstance() {
|
||||
return createRandomInput();
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,59 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.client.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImportance> {
|
||||
|
||||
@Override
|
||||
protected FeatureImportance createTestInstance() {
|
||||
return new FeatureImportance(
|
||||
randomAlphaOfLength(10),
|
||||
randomDoubleBetween(-10.0, 10.0, false),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(() -> randomAlphaOfLength(10))
|
||||
.limit(randomLongBetween(2, 10))
|
||||
.collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FeatureImportance doParseInstance(XContentParser parser) throws IOException {
|
||||
return FeatureImportance.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> field.equals(FeatureImportance.CLASS_IMPORTANCE);
|
||||
}
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.client.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class TopClassEntryTests extends AbstractXContentTestCase<TopClassEntry> {
|
||||
@Override
|
||||
protected TopClassEntry createTestInstance() {
|
||||
Object classification;
|
||||
if (randomBoolean()) {
|
||||
classification = randomAlphaOfLength(10);
|
||||
} else if (randomBoolean()) {
|
||||
classification = randomBoolean();
|
||||
} else {
|
||||
classification = randomDouble();
|
||||
}
|
||||
return new TopClassEntry(classification, randomDouble(), randomDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TopClassEntry doParseInstance(XContentParser parser) throws IOException {
|
||||
return TopClassEntry.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
@ -62,6 +62,7 @@ This page lists all the available aggregations with their corresponding `Aggrega
|
||||
| Pipeline on | PipelineAggregationBuilder Class | Method in PipelineAggregatorBuilders
|
||||
| {ref}/search-aggregations-pipeline-avg-bucket-aggregation.html[Avg Bucket] | {agg-ref}/pipeline/bucketmetrics/avg/AvgBucketPipelineAggregationBuilder.html[AvgBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#avgBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.avgBucket()]
|
||||
| {ref}/search-aggregations-pipeline-derivative-aggregation.html[Derivative] | {agg-ref}/pipeline/derivative/DerivativePipelineAggregationBuilder.html[DerivativePipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#derivative-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.derivative()]
|
||||
| {ref}/search-aggregations-pipeline-inference-bucket-aggregation.html[Inference] | {javadoc-client}/analytics/InferencePipelineAggregationBuilder.html[InferencePipelineAggregationBuilder] | None
|
||||
| {ref}/search-aggregations-pipeline-max-bucket-aggregation.html[Max Bucket] | {agg-ref}/pipeline/bucketmetrics/max/MaxBucketPipelineAggregationBuilder.html[MaxBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#maxBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.maxBucket()]
|
||||
| {ref}/search-aggregations-pipeline-min-bucket-aggregation.html[Min Bucket] | {agg-ref}/pipeline/bucketmetrics/min/MinBucketPipelineAggregationBuilder.html[MinBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#minBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.minBucket()]
|
||||
| {ref}/search-aggregations-pipeline-sum-bucket-aggregation.html[Sum Bucket] | {agg-ref}/pipeline/bucketmetrics/sum/SumBucketPipelineAggregationBuilder.html[SumBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#sumBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.sumBucket()]
|
||||
|
Loading…
x
Reference in New Issue
Block a user