Add Inference Pipeline aggregation to HLRC (#59086) (#59250)

Adds InferencePipelineAggregationBuilder to the HLRC, duplicating
the server-side classes.
This commit is contained in:
David Kyle 2020-07-09 13:38:45 +01:00 committed by GitHub
parent d56fc72ee5
commit c5443f78ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 747 additions and 1 deletions

View File

@ -54,6 +54,8 @@ import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.client.analytics.InferencePipelineAggregationBuilder;
import org.elasticsearch.client.analytics.ParsedInference;
import org.elasticsearch.client.analytics.ParsedStringStats;
import org.elasticsearch.client.analytics.ParsedTopMetrics;
import org.elasticsearch.client.analytics.StringStatsAggregationBuilder;
@ -1957,6 +1959,7 @@ public class RestHighLevelClient implements Closeable {
map.put(CompositeAggregationBuilder.NAME, (p, c) -> ParsedComposite.fromXContent(p, (String) c));
map.put(StringStatsAggregationBuilder.NAME, (p, c) -> ParsedStringStats.PARSER.parse(p, (String) c));
map.put(TopMetricsAggregationBuilder.NAME, (p, c) -> ParsedTopMetrics.PARSER.parse(p, (String) c));
map.put(InferencePipelineAggregationBuilder.NAME, (p, c) -> ParsedInference.fromXContent(p, (String ) (c)));
List<NamedXContentRegistry.Entry> entries = map.entrySet().stream()
.map(entry -> new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(entry.getKey()), entry.getValue()))
.collect(Collectors.toList());

View File

@ -0,0 +1,141 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.analytics;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
* For building inference pipeline aggregations
*
* NOTE: This extends {@linkplain AbstractPipelineAggregationBuilder} for compatibility
* with {@link SearchSourceBuilder#aggregation(PipelineAggregationBuilder)} but it
* doesn't support any "server" side things like {@linkplain #doWriteTo(StreamOutput)}
* or {@linkplain #createInternal(Map)}
*/
public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {

    // Aggregation type name; must match the name registered on the server side.
    // Declared final: this is a constant and must not be reassignable (the original
    // declaration was a mutable "public static String").
    public static final String NAME = "inference";

    public static final ParseField MODEL_ID = new ParseField("model_id");
    private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, String> PARSER = new ConstructingObjectParser<>(
        NAME, false,
        (args, name) -> new InferencePipelineAggregationBuilder(name, (String) args[0], (Map<String, String>) args[1])
    );

    static {
        PARSER.declareString(constructorArg(), MODEL_ID);
        PARSER.declareObject(constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD);
        // inference_config is optional and parsed via the named-object registry so
        // new config types can be added without changing this class.
        PARSER.declareNamedObject(InferencePipelineAggregationBuilder::setInferenceConfig,
            (p, c, n) -> p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG);
    }

    // Keyed buckets paths as supplied by the caller; a sorted copy supplies the
    // ordered array handed to super().
    private final Map<String, String> bucketPathMap;
    private final String modelId;
    private InferenceConfig inferenceConfig;

    public static InferencePipelineAggregationBuilder parse(String pipelineAggregatorName,
                                                            XContentParser parser) {
        return PARSER.apply(parser, pipelineAggregatorName);
    }

    /**
     * @param name        the name of this aggregation
     * @param modelId     ID of the trained model to run inference against
     * @param bucketsPath map of model field name to the buckets path providing its value
     */
    public InferencePipelineAggregationBuilder(String name, String modelId, Map<String, String> bucketsPath) {
        // Sort by key (TreeMap) so the buckets-path array order is deterministic.
        super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[0]));
        this.modelId = modelId;
        this.bucketPathMap = bucketsPath;
    }

    public void setInferenceConfig(InferenceConfig inferenceConfig) {
        this.inferenceConfig = inferenceConfig;
    }

    @Override
    protected void validate(ValidationContext context) {
        // validation occurs on the server
    }

    @Override
    protected void doWriteTo(StreamOutput out) {
        // Client-side only class; never serialized over the transport layer.
        throw new UnsupportedOperationException();
    }

    @Override
    protected PipelineAggregator createInternal(Map<String, Object> metaData) {
        // Reduction happens on the server; the client never creates the aggregator.
        throw new UnsupportedOperationException();
    }

    @Override
    protected boolean overrideBucketsPath() {
        // buckets_path is written as a keyed object in internalXContent below,
        // not the default array form.
        return true;
    }

    @Override
    protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
        builder.field(MODEL_ID.getPreferredName(), modelId);
        builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketPathMap);
        if (inferenceConfig != null) {
            builder.startObject(INFERENCE_CONFIG.getPreferredName());
            builder.field(inferenceConfig.getName(), inferenceConfig);
            builder.endObject();
        }
        return builder;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), bucketPathMap, modelId, inferenceConfig);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) return true;
        if (obj == null || getClass() != obj.getClass()) return false;
        if (super.equals(obj) == false) return false;
        InferencePipelineAggregationBuilder other = (InferencePipelineAggregationBuilder) obj;
        return Objects.equals(bucketPathMap, other.bucketPathMap)
            && Objects.equals(modelId, other.modelId)
            && Objects.equals(inferenceConfig, other.inferenceConfig);
    }
}

View File

@ -0,0 +1,137 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.analytics;
import org.elasticsearch.client.ml.inference.results.FeatureImportance;
import org.elasticsearch.client.ml.inference.results.TopClassEntry;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.ParsedAggregation;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
 * This class parses the superset of all possible fields that may be written by
 * InferenceResults. The warning field is mutually exclusive with all the other fields.
 *
 * In the case of classification results {@link #getValue()} may return a String,
 * Boolean or a Double. For regression results {@link #getValue()} is always
 * a Double.
 */
public class ParsedInference extends ParsedAggregation {

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
        new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
            args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
                (List<TopClassEntry>) args[2], (String) args[3]));

    public static final ParseField FEATURE_IMPORTANCE = new ParseField("feature_importance");
    public static final ParseField WARNING = new ParseField("warning");
    public static final ParseField TOP_CLASSES = new ParseField("top_classes");

    static {
        // "value" may be a string, boolean or number depending on the inference type,
        // so it is parsed manually rather than with a typed declare* method.
        PARSER.declareField(optionalConstructorArg(), (p, n) -> {
            Object o;
            XContentParser.Token token = p.currentToken();
            if (token == XContentParser.Token.VALUE_STRING) {
                o = p.text();
            } else if (token == XContentParser.Token.VALUE_BOOLEAN) {
                o = p.booleanValue();
            } else if (token == XContentParser.Token.VALUE_NUMBER) {
                o = p.doubleValue();
            } else {
                throw new XContentParseException(p.getTokenLocation(),
                    "[" + ParsedInference.class.getSimpleName() + "] failed to parse field [" + CommonFields.VALUE + "] "
                        + "value [" + token + "] is not a string, boolean or number");
            }
            return o;
        }, CommonFields.VALUE, ObjectParser.ValueType.VALUE);
        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p), FEATURE_IMPORTANCE);
        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p), TOP_CLASSES);
        PARSER.declareString(optionalConstructorArg(), WARNING);
        declareAggregationFields(PARSER);
    }

    public static ParsedInference fromXContent(XContentParser parser, final String name) {
        ParsedInference parsed = PARSER.apply(parser, null);
        parsed.setName(name);
        return parsed;
    }

    private final Object value;
    private final List<FeatureImportance> featureImportance;
    private final List<TopClassEntry> topClasses;
    private final String warning;

    ParsedInference(Object value,
                    List<FeatureImportance> featureImportance,
                    List<TopClassEntry> topClasses,
                    String warning) {
        this.value = value;
        this.warning = warning;
        this.featureImportance = featureImportance;
        this.topClasses = topClasses;
    }

    /** @return the inferred value; String, Boolean or Double for classification, Double for regression */
    public Object getValue() {
        return value;
    }

    public List<FeatureImportance> getFeatureImportance() {
        return featureImportance;
    }

    public List<TopClassEntry> getTopClasses() {
        return topClasses;
    }

    public String getWarning() {
        return warning;
    }

    @Override
    protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
        // warning is mutually exclusive with all the other fields
        if (warning != null) {
            builder.field(WARNING.getPreferredName(), warning);
        } else {
            builder.field(CommonFields.VALUE.getPreferredName(), value);
            if (topClasses != null && topClasses.size() > 0) {
                builder.field(TOP_CLASSES.getPreferredName(), topClasses);
            }
            if (featureImportance != null && featureImportance.size() > 0) {
                builder.field(FEATURE_IMPORTANCE.getPreferredName(), featureImportance);
            }
        }
        return builder;
    }

    @Override
    public String getType() {
        return InferencePipelineAggregationBuilder.NAME;
    }
}

View File

@ -0,0 +1,112 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
 * The importance of a single feature to an inference result, with an optional
 * per-class breakdown for classification models.
 */
public class FeatureImportance implements ToXContentObject {

    public static final String IMPORTANCE = "importance";
    public static final String FEATURE_NAME = "feature_name";
    public static final String CLASS_IMPORTANCE = "class_importance";

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
        new ConstructingObjectParser<>("feature_importance", true,
            args -> new FeatureImportance((String) args[0], (Double) args[1], (Map<String, Double>) args[2])
        );

    static {
        PARSER.declareString(constructorArg(), new ParseField(FEATURE_NAME));
        PARSER.declareDouble(constructorArg(), new ParseField(IMPORTANCE));
        PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
            new ParseField(CLASS_IMPORTANCE));
    }

    public static FeatureImportance fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final String featureName;
    private final double importance;
    private final Map<String, Double> classImportance;

    public FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
        this.featureName = Objects.requireNonNull(featureName);
        this.importance = importance;
        // Store an unmodifiable view so callers cannot mutate this instance's state.
        this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
    }

    public String getFeatureName() {
        return featureName;
    }

    public double getImportance() {
        return importance;
    }

    public Map<String, Double> getClassImportance() {
        return classImportance;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FEATURE_NAME, featureName);
        builder.field(IMPORTANCE, importance);
        boolean hasClassImportance = classImportance != null && classImportance.isEmpty() == false;
        if (hasClassImportance) {
            builder.startObject(CLASS_IMPORTANCE);
            for (Map.Entry<String, Double> classEntry : classImportance.entrySet()) {
                builder.field(classEntry.getKey(), classEntry.getValue());
            }
            builder.endObject();
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object object) {
        if (this == object) {
            return true;
        }
        if (object == null || getClass() != object.getClass()) {
            return false;
        }
        FeatureImportance other = (FeatureImportance) object;
        return Objects.equals(featureName, other.featureName)
            && Objects.equals(importance, other.importance)
            && Objects.equals(classImportance, other.classImportance);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureName, importance, classImportance);
    }
}

View File

@ -0,0 +1,116 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
 * One entry in the list of most likely classes for a classification result:
 * the class value together with its probability and score.
 */
public class TopClassEntry implements ToXContentObject {

    public static final ParseField CLASS_NAME = new ParseField("class_name");
    public static final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
    public static final ParseField CLASS_SCORE = new ParseField("class_score");

    public static final String NAME = "top_class";

    private static final ConstructingObjectParser<TopClassEntry, Void> PARSER =
        new ConstructingObjectParser<>(NAME, true, args -> new TopClassEntry(args[0], (Double) args[1], (Double) args[2]));

    static {
        // class_name may be a string, boolean or number, so parse it by token type.
        PARSER.declareField(constructorArg(), (p, n) -> {
            final Object parsed;
            XContentParser.Token token = p.currentToken();
            switch (token) {
                case VALUE_STRING:
                    parsed = p.text();
                    break;
                case VALUE_BOOLEAN:
                    parsed = p.booleanValue();
                    break;
                case VALUE_NUMBER:
                    parsed = p.doubleValue();
                    break;
                default:
                    throw new XContentParseException(p.getTokenLocation(),
                        "[" + NAME + "] failed to parse field [" + CLASS_NAME + "] value [" + token
                            + "] is not a string, boolean or number");
            }
            return parsed;
        }, CLASS_NAME, ObjectParser.ValueType.VALUE);
        PARSER.declareDouble(constructorArg(), CLASS_PROBABILITY);
        PARSER.declareDouble(constructorArg(), CLASS_SCORE);
    }

    public static TopClassEntry fromXContent(XContentParser parser) throws IOException {
        return PARSER.parse(parser, null);
    }

    private final Object classification;
    private final double probability;
    private final double score;

    public TopClassEntry(Object classification, double probability, double score) {
        this.classification = Objects.requireNonNull(classification);
        this.probability = probability;
        this.score = score;
    }

    public Object getClassification() {
        return classification;
    }

    public double getProbability() {
        return probability;
    }

    public double getScore() {
        return score;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(CLASS_NAME.getPreferredName(), classification);
        builder.field(CLASS_PROBABILITY.getPreferredName(), probability);
        builder.field(CLASS_SCORE.getPreferredName(), score);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object object) {
        if (this == object) {
            return true;
        }
        if (object == null || getClass() != object.getClass()) {
            return false;
        }
        TopClassEntry other = (TopClassEntry) object;
        return probability == other.probability
            && score == other.score
            && Objects.equals(classification, other.classification);
    }

    @Override
    public int hashCode() {
        return Objects.hash(classification, probability, score);
    }
}

View File

@ -688,6 +688,7 @@ public class RestHighLevelClientTests extends ESTestCase {
// Explicitly check for metrics from the analytics module because they aren't in InternalAggregationTestCase
assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("string_stats")));
assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("top_metrics")));
assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("inference")));
assertEquals(expectedInternalAggregations + expectedSuggestions, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();

View File

@ -0,0 +1,127 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.analytics;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.ESRestHighLevelClientTestCase;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.search.aggregations.bucket.terms.ParsedTerms;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
public class InferenceAggIT extends ESRestHighLevelClientTestCase {
public void testInferenceAgg() throws IOException {
// create a very simple decision tree with a root node and 2 leaves
List<String> featureNames = Collections.singletonList("cost");
Tree.Builder builder = Tree.builder();
builder.setFeatureNames(featureNames);
TreeNode.Builder root = builder.addJunction(0, 0, true, 1.0);
int leftChild = root.getLeftChild();
int rightChild = root.getRightChild();
builder.addLeaf(leftChild, 10.0);
builder.addLeaf(rightChild, 20.0);
final String modelId = "simple_regression";
putTrainedModel(modelId, featureNames, builder.build());
final String index = "inference-test-data";
indexData(index);
TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("fruit_type").field("fruit");
AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_cost").field("cost");
termsAgg.subAggregation(avgAgg);
Map<String, String> bucketPaths = new HashMap<>();
bucketPaths.put("cost", "avg_cost");
InferencePipelineAggregationBuilder inferenceAgg = new InferencePipelineAggregationBuilder("infer", modelId, bucketPaths);
termsAgg.subAggregation(inferenceAgg);
SearchRequest search = new SearchRequest(index);
search.source().aggregation(termsAgg);
SearchResponse response = highLevelClient().search(search, RequestOptions.DEFAULT);
ParsedTerms terms = response.getAggregations().get("fruit_type");
List<? extends Terms.Bucket> buckets = terms.getBuckets();
{
assertThat(buckets.get(0).getKey(), equalTo("apple"));
ParsedInference inference = buckets.get(0).getAggregations().get("infer");
assertThat((Double) inference.getValue(), closeTo(20.0, 0.01));
assertNull(inference.getWarning());
assertNull(inference.getFeatureImportance());
assertNull(inference.getTopClasses());
}
{
assertThat(buckets.get(1).getKey(), equalTo("banana"));
ParsedInference inference = buckets.get(1).getAggregations().get("infer");
assertThat((Double) inference.getValue(), closeTo(10.0, 0.01));
assertNull(inference.getWarning());
assertNull(inference.getFeatureImportance());
assertNull(inference.getTopClasses());
}
}
private void putTrainedModel(String modelId, List<String> inputFields, Tree tree) throws IOException {
TrainedModelDefinition definition = new TrainedModelDefinition.Builder().setTrainedModel(tree).build();
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
.setDefinition(definition)
.setModelId(modelId)
.setInferenceConfig(new RegressionConfig())
.setInput(new TrainedModelInput(inputFields))
.setDescription("test model")
.build();
highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
}
private void indexData(String index) throws IOException {
CreateIndexRequest create = new CreateIndexRequest(index);
create.mapping("{\"properties\": {\"fruit\": {\"type\": \"keyword\"}," +
"\"cost\": {\"type\": \"double\"}}}", XContentType.JSON);
highLevelClient().indices().create(create, RequestOptions.DEFAULT);
BulkRequest bulk = new BulkRequest(index).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
bulk.add(new IndexRequest().source(XContentType.JSON, "fruit", "apple", "cost", "1.2"));
bulk.add(new IndexRequest().source(XContentType.JSON, "fruit", "banana", "cost", "0.8"));
bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
}
}

View File

@ -54,5 +54,4 @@ public class TrainedModelInputTests extends AbstractXContentTestCase<TrainedMode
protected TrainedModelInput createTestInstance() {
    // Delegate to the shared random factory so randomization stays consistent
    // with other tests that build TrainedModelInput instances.
    return createRandomInput();
}
}

View File

@ -0,0 +1,59 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.results;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImportance> {
@Override
protected FeatureImportance createTestInstance() {
return new FeatureImportance(
randomAlphaOfLength(10),
randomDoubleBetween(-10.0, 10.0, false),
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(2, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
}
@Override
protected FeatureImportance doParseInstance(XContentParser parser) throws IOException {
return FeatureImportance.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> field.equals(FeatureImportance.CLASS_IMPORTANCE);
}
}

View File

@ -0,0 +1,50 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.results;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class TopClassEntryTests extends AbstractXContentTestCase<TopClassEntry> {

    @Override
    protected TopClassEntry createTestInstance() {
        // The class value may legitimately be a string, a boolean or a number.
        final Object classification = randomBoolean()
            ? randomAlphaOfLength(10)
            : randomBoolean() ? (Object) randomBoolean() : randomDouble();
        return new TopClassEntry(classification, randomDouble(), randomDouble());
    }

    @Override
    protected TopClassEntry doParseInstance(XContentParser parser) throws IOException {
        return TopClassEntry.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }
}

View File

@ -62,6 +62,7 @@ This page lists all the available aggregations with their corresponding `Aggrega
| Pipeline on | PipelineAggregationBuilder Class | Method in PipelineAggregatorBuilders
| {ref}/search-aggregations-pipeline-avg-bucket-aggregation.html[Avg Bucket] | {agg-ref}/pipeline/bucketmetrics/avg/AvgBucketPipelineAggregationBuilder.html[AvgBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#avgBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.avgBucket()]
| {ref}/search-aggregations-pipeline-derivative-aggregation.html[Derivative] | {agg-ref}/pipeline/derivative/DerivativePipelineAggregationBuilder.html[DerivativePipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#derivative-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.derivative()]
| {ref}/search-aggregations-pipeline-inference-bucket-aggregation.html[Inference] | {javadoc-client}/analytics/InferencePipelineAggregationBuilder.html[InferencePipelineAggregationBuilder] | None
| {ref}/search-aggregations-pipeline-max-bucket-aggregation.html[Max Bucket] | {agg-ref}/pipeline/bucketmetrics/max/MaxBucketPipelineAggregationBuilder.html[MaxBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#maxBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.maxBucket()]
| {ref}/search-aggregations-pipeline-min-bucket-aggregation.html[Min Bucket] | {agg-ref}/pipeline/bucketmetrics/min/MinBucketPipelineAggregationBuilder.html[MinBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#minBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.minBucket()]
| {ref}/search-aggregations-pipeline-sum-bucket-aggregation.html[Sum Bucket] | {agg-ref}/pipeline/bucketmetrics/sum/SumBucketPipelineAggregationBuilder.html[SumBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#sumBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.sumBucket()]