This adds the necessary named XContent classes to the HLRC for the lang ident model. This is so the HLRC can call `GET _ml/inference/lang_ident_model_1?include_definition=true` without XContent parsing errors. The constructors are package private as since this classes are used exclusively within the pre-packaged model (and require the specific weights, etc. to be of any use).
This commit is contained in:
parent
3e014d39c2
commit
cc0e64572a
|
@ -18,12 +18,14 @@
|
|||
*/
|
||||
package org.elasticsearch.client.ml.inference;
|
||||
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
|
||||
|
@ -49,10 +51,15 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
|
|||
TargetMeanEncoding::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(FrequencyEncoding.NAME),
|
||||
FrequencyEncoding::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(CustomWordEmbedding.NAME),
|
||||
CustomWordEmbedding::fromXContent));
|
||||
|
||||
// Model
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Ensemble.NAME), Ensemble::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class,
|
||||
new ParseField(LangIdentNeuralNetwork.NAME),
|
||||
LangIdentNeuralNetwork::fromXContent));
|
||||
|
||||
// Aggregating output
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.CheckedFunction;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* This is a pre-processor that embeds text into a numerical vector.
|
||||
*
|
||||
* It calculates a set of features based on script type, ngram hashes, and most common script values.
|
||||
*
|
||||
* The features are then concatenated with specific quantization scales and weights into a vector of length 80.
|
||||
*
|
||||
* This is a fork and a port of: https://github.com/google/cld3/blob/06f695f1c8ee530104416aab5dcf2d6a1414a56a/src/embedding_network.cc
|
||||
*/
|
||||
public class CustomWordEmbedding implements PreProcessor {
|
||||
|
||||
public static final String NAME = "custom_word_embedding";
|
||||
static final ParseField FIELD = new ParseField("field");
|
||||
static final ParseField DEST_FIELD = new ParseField("dest_field");
|
||||
static final ParseField EMBEDDING_WEIGHTS = new ParseField("embedding_weights");
|
||||
static final ParseField EMBEDDING_QUANT_SCALES = new ParseField("embedding_quant_scales");
|
||||
|
||||
public static final ConstructingObjectParser<CustomWordEmbedding, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
true,
|
||||
a -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3]));
|
||||
static {
|
||||
PARSER.declareField(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> {
|
||||
List<List<Short>> listOfListOfShorts = parseArrays(EMBEDDING_QUANT_SCALES.getPreferredName(),
|
||||
XContentParser::shortValue,
|
||||
p);
|
||||
short[][] primitiveShorts = new short[listOfListOfShorts.size()][];
|
||||
int i = 0;
|
||||
for (List<Short> shorts : listOfListOfShorts) {
|
||||
short[] innerShorts = new short[shorts.size()];
|
||||
for (int j = 0; j < shorts.size(); j++) {
|
||||
innerShorts[j] = shorts.get(j);
|
||||
}
|
||||
primitiveShorts[i++] = innerShorts;
|
||||
}
|
||||
return primitiveShorts;
|
||||
},
|
||||
EMBEDDING_QUANT_SCALES,
|
||||
ObjectParser.ValueType.VALUE_ARRAY);
|
||||
PARSER.declareField(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> {
|
||||
List<byte[]> values = new ArrayList<>();
|
||||
while(p.nextToken() != XContentParser.Token.END_ARRAY) {
|
||||
values.add(p.binaryValue());
|
||||
}
|
||||
byte[][] primitiveBytes = new byte[values.size()][];
|
||||
int i = 0;
|
||||
for (byte[] bytes : values) {
|
||||
primitiveBytes[i++] = bytes;
|
||||
}
|
||||
return primitiveBytes;
|
||||
},
|
||||
EMBEDDING_WEIGHTS,
|
||||
ObjectParser.ValueType.VALUE_ARRAY);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEST_FIELD);
|
||||
}
|
||||
|
||||
private static <T> List<List<T>> parseArrays(String fieldName,
|
||||
CheckedFunction<XContentParser, T, IOException> fromParser,
|
||||
XContentParser p) throws IOException {
|
||||
if (p.currentToken() != XContentParser.Token.START_ARRAY) {
|
||||
throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]");
|
||||
}
|
||||
List<List<T>> values = new ArrayList<>();
|
||||
while(p.nextToken() != XContentParser.Token.END_ARRAY) {
|
||||
if (p.currentToken() != XContentParser.Token.START_ARRAY) {
|
||||
throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]");
|
||||
}
|
||||
List<T> innerList = new ArrayList<>();
|
||||
while(p.nextToken() != XContentParser.Token.END_ARRAY) {
|
||||
if(p.currentToken().isValue() == false) {
|
||||
throw new IllegalStateException("expected non-null value but got [" + p.currentToken() + "] " +
|
||||
"for [" + fieldName + "]");
|
||||
}
|
||||
innerList.add(fromParser.apply(p));
|
||||
}
|
||||
values.add(innerList);
|
||||
}
|
||||
return values;
|
||||
}
|
||||
|
||||
public static CustomWordEmbedding fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final short[][] embeddingsQuantScales;
|
||||
private final byte[][] embeddingsWeights;
|
||||
private final String fieldName;
|
||||
private final String destField;
|
||||
|
||||
CustomWordEmbedding(short[][] embeddingsQuantScales, byte[][] embeddingsWeights, String fieldName, String destField) {
|
||||
this.embeddingsQuantScales = embeddingsQuantScales;
|
||||
this.embeddingsWeights = embeddingsWeights;
|
||||
this.fieldName = fieldName;
|
||||
this.destField = destField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FIELD.getPreferredName(), fieldName);
|
||||
builder.field(DEST_FIELD.getPreferredName(), destField);
|
||||
builder.field(EMBEDDING_QUANT_SCALES.getPreferredName(), embeddingsQuantScales);
|
||||
builder.field(EMBEDDING_WEIGHTS.getPreferredName(), embeddingsWeights);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
CustomWordEmbedding that = (CustomWordEmbedding) o;
|
||||
return Objects.equals(fieldName, that.fieldName)
|
||||
&& Objects.equals(destField, that.destField)
|
||||
&& Arrays.deepEquals(embeddingsWeights, that.embeddingsWeights)
|
||||
&& Arrays.deepEquals(embeddingsQuantScales, that.embeddingsQuantScales);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(fieldName, destField, Arrays.deepHashCode(embeddingsQuantScales), Arrays.deepHashCode(embeddingsWeights));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.trainedmodel.langident;
|
||||
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* Shallow, fully connected, feed forward NN modeled after and ported from https://github.com/google/cld3
|
||||
*/
|
||||
public class LangIdentNeuralNetwork implements TrainedModel {
|
||||
|
||||
public static final String NAME = "lang_ident_neural_network";
|
||||
public static final ParseField EMBEDDED_VECTOR_FEATURE_NAME = new ParseField("embedded_vector_feature_name");
|
||||
public static final ParseField HIDDEN_LAYER = new ParseField("hidden_layer");
|
||||
public static final ParseField SOFTMAX_LAYER = new ParseField("softmax_layer");
|
||||
public static final ConstructingObjectParser<LangIdentNeuralNetwork, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
true,
|
||||
a -> new LangIdentNeuralNetwork((String) a[0],
|
||||
(LangNetLayer) a[1],
|
||||
(LangNetLayer) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), EMBEDDED_VECTOR_FEATURE_NAME);
|
||||
PARSER.declareObject(constructorArg(), LangNetLayer.PARSER::apply, HIDDEN_LAYER);
|
||||
PARSER.declareObject(constructorArg(), LangNetLayer.PARSER::apply, SOFTMAX_LAYER);
|
||||
}
|
||||
|
||||
public static LangIdentNeuralNetwork fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final LangNetLayer hiddenLayer;
|
||||
private final LangNetLayer softmaxLayer;
|
||||
private final String embeddedVectorFeatureName;
|
||||
|
||||
LangIdentNeuralNetwork(String embeddedVectorFeatureName,
|
||||
LangNetLayer hiddenLayer,
|
||||
LangNetLayer softmaxLayer) {
|
||||
this.embeddedVectorFeatureName = embeddedVectorFeatureName;
|
||||
this.hiddenLayer = hiddenLayer;
|
||||
this.softmaxLayer = softmaxLayer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getFeatureNames() {
|
||||
return Collections.singletonList(embeddedVectorFeatureName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(EMBEDDED_VECTOR_FEATURE_NAME.getPreferredName(), embeddedVectorFeatureName);
|
||||
builder.field(HIDDEN_LAYER.getPreferredName(), hiddenLayer);
|
||||
builder.field(SOFTMAX_LAYER.getPreferredName(), softmaxLayer);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
LangIdentNeuralNetwork that = (LangIdentNeuralNetwork) o;
|
||||
return Objects.equals(embeddedVectorFeatureName, that.embeddedVectorFeatureName)
|
||||
&& Objects.equals(hiddenLayer, that.hiddenLayer)
|
||||
&& Objects.equals(softmaxLayer, that.softmaxLayer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(embeddedVectorFeatureName, hiddenLayer, softmaxLayer);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.trainedmodel.langident;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* Represents a single layer in the compressed Lang Net
|
||||
*/
|
||||
public class LangNetLayer implements ToXContentObject {
|
||||
|
||||
public static final ParseField NAME = new ParseField("lang_net_layer");
|
||||
|
||||
private static final ParseField NUM_ROWS = new ParseField("num_rows");
|
||||
private static final ParseField NUM_COLS = new ParseField("num_cols");
|
||||
private static final ParseField WEIGHTS = new ParseField("weights");
|
||||
private static final ParseField BIAS = new ParseField("bias");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<LangNetLayer, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
true,
|
||||
a -> new LangNetLayer(
|
||||
(List<Double>) a[0],
|
||||
(int) a[1],
|
||||
(int) a[2],
|
||||
(List<Double>) a[3]));
|
||||
|
||||
static {
|
||||
PARSER.declareDoubleArray(constructorArg(), WEIGHTS);
|
||||
PARSER.declareInt(constructorArg(), NUM_COLS);
|
||||
PARSER.declareInt(constructorArg(), NUM_ROWS);
|
||||
PARSER.declareDoubleArray(constructorArg(), BIAS);
|
||||
}
|
||||
|
||||
private final double[] weights;
|
||||
private final int weightRows;
|
||||
private final int weightCols;
|
||||
private final double[] bias;
|
||||
|
||||
private LangNetLayer(List<Double> weights, int numCols, int numRows, List<Double> bias) {
|
||||
this(weights.stream().mapToDouble(Double::doubleValue).toArray(),
|
||||
numCols,
|
||||
numRows,
|
||||
bias.stream().mapToDouble(Double::doubleValue).toArray());
|
||||
}
|
||||
|
||||
LangNetLayer(double[] weights, int numCols, int numRows, double[] bias) {
|
||||
this.weights = weights;
|
||||
this.weightCols = numCols;
|
||||
this.weightRows = numRows;
|
||||
this.bias = bias;
|
||||
}
|
||||
|
||||
double[] getWeights() {
|
||||
return weights;
|
||||
}
|
||||
|
||||
int getWeightRows() {
|
||||
return weightRows;
|
||||
}
|
||||
|
||||
int getWeightCols() {
|
||||
return weightCols;
|
||||
}
|
||||
|
||||
double[] getBias() {
|
||||
return bias;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(NUM_COLS.getPreferredName(), weightCols);
|
||||
builder.field(NUM_ROWS.getPreferredName(), weightRows);
|
||||
builder.field(WEIGHTS.getPreferredName(), weights);
|
||||
builder.field(BIAS.getPreferredName(), bias);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
LangNetLayer that = (LangNetLayer) o;
|
||||
return Arrays.equals(weights, that.weights)
|
||||
&& Arrays.equals(bias, that.bias)
|
||||
&& Objects.equals(weightCols, that.weightCols)
|
||||
&& Objects.equals(weightRows, that.weightRows);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(bias), weightCols, weightRows);
|
||||
}
|
||||
}
|
|
@ -151,6 +151,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
|||
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelStats;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
|
||||
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
||||
import org.elasticsearch.client.ml.job.config.DataDescription;
|
||||
import org.elasticsearch.client.ml.job.config.Detector;
|
||||
|
@ -201,6 +202,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
|||
import static org.hamcrest.Matchers.hasItem;
|
||||
import static org.hamcrest.Matchers.hasItems;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.lessThan;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
|
@ -2308,6 +2310,21 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(0));
|
||||
}
|
||||
|
||||
public void testGetPrepackagedModels() throws Exception {
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
|
||||
GetTrainedModelsResponse getTrainedModelsResponse = execute(
|
||||
new GetTrainedModelsRequest("lang_ident_model_1").setIncludeDefinition(true),
|
||||
machineLearningClient::getTrainedModels,
|
||||
machineLearningClient::getTrainedModelsAsync);
|
||||
|
||||
assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
|
||||
assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
|
||||
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo("lang_ident_model_1"));
|
||||
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition().getTrainedModel(),
|
||||
instanceOf(LangIdentNeuralNetwork.class));
|
||||
}
|
||||
|
||||
public void testPutFilter() throws Exception {
|
||||
String filterId = "filter-job-test";
|
||||
MlFilter mlFilter = MlFilter.builder(filterId)
|
||||
|
|
|
@ -68,6 +68,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar
|
|||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
|
||||
|
@ -75,6 +76,7 @@ import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
|
|||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
|
||||
import org.elasticsearch.client.transform.transforms.SyncConfig;
|
||||
import org.elasticsearch.client.transform.transforms.TimeSyncConfig;
|
||||
|
@ -688,7 +690,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(55, namedXContents.size());
|
||||
assertEquals(57, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
|
@ -760,10 +762,10 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
|
||||
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
|
||||
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
|
||||
assertThat(names, hasItems(Tree.NAME, Ensemble.NAME));
|
||||
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
|
||||
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
|
||||
assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
|
||||
assertEquals(Integer.valueOf(3),
|
||||
categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
|
||||
assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME));
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
||||
public class CustomWordEmbeddingTests extends AbstractXContentTestCase<CustomWordEmbedding> {
|
||||
|
||||
@Override
|
||||
protected CustomWordEmbedding doParseInstance(XContentParser parser) throws IOException {
|
||||
return CustomWordEmbedding.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected CustomWordEmbedding createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static CustomWordEmbedding createRandom() {
|
||||
int quantileSize = randomIntBetween(1, 10);
|
||||
int internalQuantSize = randomIntBetween(1, 10);
|
||||
short[][] quantiles = new short[quantileSize][internalQuantSize];
|
||||
for (int i = 0; i < quantileSize; i++) {
|
||||
for (int j = 0; j < internalQuantSize; j++) {
|
||||
quantiles[i][j] = randomShort();
|
||||
}
|
||||
}
|
||||
int weightsSize = randomIntBetween(1, 10);
|
||||
int internalWeightsSize = randomIntBetween(1, 10);
|
||||
byte[][] weights = new byte[weightsSize][internalWeightsSize];
|
||||
for (int i = 0; i < weightsSize; i++) {
|
||||
for (int j = 0; j < internalWeightsSize; j++) {
|
||||
weights[i][j] = randomByte();
|
||||
}
|
||||
}
|
||||
return new CustomWordEmbedding(quantiles, weights, randomAlphaOfLength(10), randomAlphaOfLength(10));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.trainedmodel.langident;
|
||||
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
||||
public class LangIdentNeuralNetworkTests extends AbstractXContentTestCase<LangIdentNeuralNetwork> {
|
||||
|
||||
@Override
|
||||
protected LangIdentNeuralNetwork doParseInstance(XContentParser parser) throws IOException {
|
||||
return LangIdentNeuralNetwork.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected LangIdentNeuralNetwork createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static LangIdentNeuralNetwork createRandom() {
|
||||
return new LangIdentNeuralNetwork(randomAlphaOfLength(10),
|
||||
LangNetLayerTests.createRandom(),
|
||||
LangNetLayerTests.createRandom());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.trainedmodel.langident;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
||||
public class LangNetLayerTests extends AbstractXContentTestCase<LangNetLayer> {
|
||||
|
||||
@Override
|
||||
protected LangNetLayer doParseInstance(XContentParser parser) throws IOException {
|
||||
return LangNetLayer.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected LangNetLayer createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static LangNetLayer createRandom() {
|
||||
int numWeights = randomIntBetween(1, 1000);
|
||||
return new LangNetLayer(
|
||||
Stream.generate(ESTestCase::randomDouble).limit(numWeights).mapToDouble(Double::doubleValue).toArray(),
|
||||
numWeights,
|
||||
1,
|
||||
Stream.generate(ESTestCase::randomDouble).limit(numWeights).mapToDouble(Double::doubleValue).toArray());
|
||||
}
|
||||
|
||||
}
|
|
@ -198,6 +198,16 @@ public class TrainedModelIT extends ESRestTestCase {
|
|||
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
|
||||
}
|
||||
|
||||
public void testGetPrePackagedModels() throws IOException {
|
||||
Response getModel = client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH + "inference/lang_ident_model_1?human=true&include_model_definition=true"));
|
||||
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
String response = EntityUtils.toString(getModel.getEntity());
|
||||
assertThat(response, containsString("lang_ident_model_1"));
|
||||
assertThat(response, containsString("\"definition\""));
|
||||
}
|
||||
|
||||
private static String buildRegressionModel(String modelId) throws IOException {
|
||||
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||
TrainedModelConfig.builder()
|
||||
|
|
Loading…
Reference in New Issue