[ML][Inference][HLRC] Add necessary lang ident classes (#50705) (#50794)

This adds the necessary named XContent classes to the HLRC for the lang ident model. This is so the HLRC can call `GET _ml/inference/lang_ident_model_1?include_definition=true` without XContent parsing errors.

The constructors are package private since these classes are used exclusively within the pre-packaged model (and require the specific weights, etc. to be of any use).
This commit is contained in:
Benjamin Trent 2020-01-09 10:33:38 -05:00 committed by GitHub
parent 3e014d39c2
commit cc0e64572a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 614 additions and 5 deletions

View File

@ -18,12 +18,14 @@
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
@ -49,10 +51,15 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
TargetMeanEncoding::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(FrequencyEncoding.NAME),
FrequencyEncoding::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(CustomWordEmbedding.NAME),
CustomWordEmbedding::fromXContent));
// Model
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Ensemble.NAME), Ensemble::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class,
new ParseField(LangIdentNeuralNetwork.NAME),
LangIdentNeuralNetwork::fromXContent));
// Aggregating output
namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,

View File

@ -0,0 +1,166 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
/**
* This is a pre-processor that embeds text into a numerical vector.
*
* It calculates a set of features based on script type, ngram hashes, and most common script values.
*
* The features are then concatenated with specific quantization scales and weights into a vector of length 80.
*
* This is a fork and a port of: https://github.com/google/cld3/blob/06f695f1c8ee530104416aab5dcf2d6a1414a56a/src/embedding_network.cc
*/
/**
 * A pre-processor that embeds text into a numerical vector.
 *
 * It calculates a set of features based on script type, ngram hashes, and most common script values.
 *
 * The features are then concatenated with specific quantization scales and weights into a vector of length 80.
 *
 * This is a fork and a port of: https://github.com/google/cld3/blob/06f695f1c8ee530104416aab5dcf2d6a1414a56a/src/embedding_network.cc
 */
public class CustomWordEmbedding implements PreProcessor {

    public static final String NAME = "custom_word_embedding";
    static final ParseField FIELD = new ParseField("field");
    static final ParseField DEST_FIELD = new ParseField("dest_field");
    static final ParseField EMBEDDING_WEIGHTS = new ParseField("embedding_weights");
    static final ParseField EMBEDDING_QUANT_SCALES = new ParseField("embedding_quant_scales");

    // Lenient parser (ignores unknown fields); argument positions follow the declaration order below.
    public static final ConstructingObjectParser<CustomWordEmbedding, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
        args -> new CustomWordEmbedding((short[][]) args[0], (byte[][]) args[1], (String) args[2], (String) args[3]));

    static {
        PARSER.declareField(ConstructingObjectParser.constructorArg(),
            (p, c) -> {
                // Parse a nested array of shorts and unbox it into a primitive matrix.
                List<List<Short>> parsedRows = parseArrays(EMBEDDING_QUANT_SCALES.getPreferredName(),
                    XContentParser::shortValue,
                    p);
                short[][] quantScales = new short[parsedRows.size()][];
                for (int row = 0; row < parsedRows.size(); row++) {
                    List<Short> parsedRow = parsedRows.get(row);
                    short[] rowValues = new short[parsedRow.size()];
                    for (int col = 0; col < parsedRow.size(); col++) {
                        rowValues[col] = parsedRow.get(col);
                    }
                    quantScales[row] = rowValues;
                }
                return quantScales;
            },
            EMBEDDING_QUANT_SCALES,
            ObjectParser.ValueType.VALUE_ARRAY);
        PARSER.declareField(ConstructingObjectParser.constructorArg(),
            (p, c) -> {
                // Each array entry is a base64-encoded binary value; collect them into a byte matrix.
                List<byte[]> rows = new ArrayList<>();
                while (p.nextToken() != XContentParser.Token.END_ARRAY) {
                    rows.add(p.binaryValue());
                }
                return rows.toArray(new byte[rows.size()][]);
            },
            EMBEDDING_WEIGHTS,
            ObjectParser.ValueType.VALUE_ARRAY);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), DEST_FIELD);
    }

    /**
     * Parses a two-level nested array, applying {@code valueReader} to every scalar element.
     *
     * @throws IllegalArgumentException if either nesting level does not start with an array token
     * @throws IllegalStateException if a non-value (e.g. null or object) token is found inside an inner array
     */
    private static <T> List<List<T>> parseArrays(String fieldName,
                                                 CheckedFunction<XContentParser, T, IOException> valueReader,
                                                 XContentParser p) throws IOException {
        if (p.currentToken() != XContentParser.Token.START_ARRAY) {
            throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]");
        }
        List<T> innerList;
        List<List<T>> rows = new ArrayList<>();
        while (p.nextToken() != XContentParser.Token.END_ARRAY) {
            if (p.currentToken() != XContentParser.Token.START_ARRAY) {
                throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]");
            }
            innerList = new ArrayList<>();
            while (p.nextToken() != XContentParser.Token.END_ARRAY) {
                if (p.currentToken().isValue() == false) {
                    throw new IllegalStateException("expected non-null value but got [" + p.currentToken() + "] " +
                        "for [" + fieldName + "]");
                }
                innerList.add(valueReader.apply(p));
            }
            rows.add(innerList);
        }
        return rows;
    }

    public static CustomWordEmbedding fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final short[][] embeddingsQuantScales;
    private final byte[][] embeddingsWeights;
    private final String fieldName;
    private final String destField;

    // Package private: instances are only created while parsing the pre-packaged lang ident model.
    CustomWordEmbedding(short[][] embeddingsQuantScales, byte[][] embeddingsWeights, String fieldName, String destField) {
        this.embeddingsQuantScales = embeddingsQuantScales;
        this.embeddingsWeights = embeddingsWeights;
        this.fieldName = fieldName;
        this.destField = destField;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD.getPreferredName(), fieldName);
        builder.field(DEST_FIELD.getPreferredName(), destField);
        builder.field(EMBEDDING_QUANT_SCALES.getPreferredName(), embeddingsQuantScales);
        builder.field(EMBEDDING_WEIGHTS.getPreferredName(), embeddingsWeights);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        CustomWordEmbedding other = (CustomWordEmbedding) o;
        // deepEquals is required for element-wise comparison of the nested arrays.
        return Objects.equals(fieldName, other.fieldName)
            && Objects.equals(destField, other.destField)
            && Arrays.deepEquals(embeddingsWeights, other.embeddingsWeights)
            && Arrays.deepEquals(embeddingsQuantScales, other.embeddingsQuantScales);
    }

    @Override
    public int hashCode() {
        return Objects.hash(fieldName, destField, Arrays.deepHashCode(embeddingsQuantScales), Arrays.deepHashCode(embeddingsWeights));
    }
}

View File

@ -0,0 +1,108 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.langident;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
* Shallow, fully connected, feed forward NN modeled after and ported from https://github.com/google/cld3
*/
/**
 * Shallow, fully connected, feed forward NN modeled after and ported from https://github.com/google/cld3
 */
public class LangIdentNeuralNetwork implements TrainedModel {

    public static final String NAME = "lang_ident_neural_network";
    public static final ParseField EMBEDDED_VECTOR_FEATURE_NAME = new ParseField("embedded_vector_feature_name");
    public static final ParseField HIDDEN_LAYER = new ParseField("hidden_layer");
    public static final ParseField SOFTMAX_LAYER = new ParseField("softmax_layer");

    // Lenient parser; constructor argument positions follow the declaration order below.
    public static final ConstructingObjectParser<LangIdentNeuralNetwork, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
        args -> new LangIdentNeuralNetwork((String) args[0], (LangNetLayer) args[1], (LangNetLayer) args[2]));

    static {
        PARSER.declareString(constructorArg(), EMBEDDED_VECTOR_FEATURE_NAME);
        PARSER.declareObject(constructorArg(), LangNetLayer.PARSER::apply, HIDDEN_LAYER);
        PARSER.declareObject(constructorArg(), LangNetLayer.PARSER::apply, SOFTMAX_LAYER);
    }

    public static LangIdentNeuralNetwork fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final String embeddedVectorFeatureName;
    private final LangNetLayer hiddenLayer;
    private final LangNetLayer softmaxLayer;

    // Package private: instances are only created while parsing the pre-packaged lang ident model.
    LangIdentNeuralNetwork(String embeddedVectorFeatureName,
                           LangNetLayer hiddenLayer,
                           LangNetLayer softmaxLayer) {
        this.embeddedVectorFeatureName = embeddedVectorFeatureName;
        this.hiddenLayer = hiddenLayer;
        this.softmaxLayer = softmaxLayer;
    }

    @Override
    public List<String> getFeatureNames() {
        // The network consumes exactly one feature: the embedded text vector.
        return Collections.singletonList(embeddedVectorFeatureName);
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(EMBEDDED_VECTOR_FEATURE_NAME.getPreferredName(), embeddedVectorFeatureName);
        builder.field(HIDDEN_LAYER.getPreferredName(), hiddenLayer);
        builder.field(SOFTMAX_LAYER.getPreferredName(), softmaxLayer);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        LangIdentNeuralNetwork other = (LangIdentNeuralNetwork) o;
        return Objects.equals(embeddedVectorFeatureName, other.embeddedVectorFeatureName)
            && Objects.equals(hiddenLayer, other.hiddenLayer)
            && Objects.equals(softmaxLayer, other.softmaxLayer);
    }

    @Override
    public int hashCode() {
        return Objects.hash(embeddedVectorFeatureName, hiddenLayer, softmaxLayer);
    }
}

View File

@ -0,0 +1,123 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.langident;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
* Represents a single layer in the compressed Lang Net
*/
/**
 * Represents a single layer in the compressed Lang Net.
 *
 * A layer holds a dense weight matrix flattened into a single array ({@code num_rows} x {@code num_cols}
 * entries) plus a bias vector. Instances are only created by parsing the pre-packaged lang ident model,
 * hence the package-private constructors.
 */
public class LangNetLayer implements ToXContentObject {

    public static final ParseField NAME = new ParseField("lang_net_layer");

    private static final ParseField NUM_ROWS = new ParseField("num_rows");
    private static final ParseField NUM_COLS = new ParseField("num_cols");
    private static final ParseField WEIGHTS = new ParseField("weights");
    private static final ParseField BIAS = new ParseField("bias");

    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<LangNetLayer, Void> PARSER = new ConstructingObjectParser<>(
        NAME.getPreferredName(),
        true,
        a -> new LangNetLayer(
            (List<Double>) a[0],
            (int) a[1],
            (int) a[2],
            (List<Double>) a[3]));

    static {
        // Declaration order fixes the constructor argument positions above: weights, num_cols, num_rows, bias.
        PARSER.declareDoubleArray(constructorArg(), WEIGHTS);
        PARSER.declareInt(constructorArg(), NUM_COLS);
        PARSER.declareInt(constructorArg(), NUM_ROWS);
        PARSER.declareDoubleArray(constructorArg(), BIAS);
    }

    private final double[] weights;
    private final int weightRows;
    private final int weightCols;
    private final double[] bias;

    // Parser-only constructor: unboxes the parsed lists into primitive arrays.
    private LangNetLayer(List<Double> weights, int numCols, int numRows, List<Double> bias) {
        this(weights.stream().mapToDouble(Double::doubleValue).toArray(),
            numCols,
            numRows,
            bias.stream().mapToDouble(Double::doubleValue).toArray());
    }

    LangNetLayer(double[] weights, int numCols, int numRows, double[] bias) {
        this.weights = weights;
        this.weightCols = numCols;
        this.weightRows = numRows;
        this.bias = bias;
    }

    double[] getWeights() {
        return weights;
    }

    int getWeightRows() {
        return weightRows;
    }

    int getWeightCols() {
        return weightCols;
    }

    double[] getBias() {
        return bias;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(NUM_COLS.getPreferredName(), weightCols);
        builder.field(NUM_ROWS.getPreferredName(), weightRows);
        builder.field(WEIGHTS.getPreferredName(), weights);
        builder.field(BIAS.getPreferredName(), bias);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        LangNetLayer that = (LangNetLayer) o;
        // Compare the primitive dimensions with ==; Objects.equals(int, int) needlessly autoboxes both values.
        return Arrays.equals(weights, that.weights)
            && Arrays.equals(bias, that.bias)
            && weightCols == that.weightCols
            && weightRows == that.weightRows;
    }

    @Override
    public int hashCode() {
        return Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(bias), weightCols, weightRows);
    }
}

View File

@ -151,6 +151,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.DataDescription;
import org.elasticsearch.client.ml.job.config.Detector;
@ -201,6 +202,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasItems;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;
@ -2308,6 +2310,21 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(0));
}
public void testGetPrepackagedModels() throws Exception {
    // The lang ident model ships with the distribution, so no model needs to be created first.
    MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
    GetTrainedModelsResponse getTrainedModelsResponse = execute(
        // setIncludeDefinition(true) forces the HLRC to parse the full model definition,
        // exercising the newly registered named XContent entries (LangIdentNeuralNetwork etc.).
        new GetTrainedModelsRequest("lang_ident_model_1").setIncludeDefinition(true),
        machineLearningClient::getTrainedModels,
        machineLearningClient::getTrainedModelsAsync);
    // Exactly one pre-packaged model is expected, whose trained model is the lang ident network.
    assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
    assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
    assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo("lang_ident_model_1"));
    assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition().getTrainedModel(),
        instanceOf(LangIdentNeuralNetwork.class));
}
public void testPutFilter() throws Exception {
String filterId = "filter-job-test";
MlFilter mlFilter = MlFilter.builder(filterId)

View File

@ -68,6 +68,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
@ -75,6 +76,7 @@ import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.client.transform.transforms.SyncConfig;
import org.elasticsearch.client.transform.transforms.TimeSyncConfig;
@ -688,7 +690,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(55, namedXContents.size());
assertEquals(57, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -760,10 +762,10 @@ public class RestHighLevelClientTests extends ESTestCase {
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
assertThat(names, hasItems(Tree.NAME, Ensemble.NAME));
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
assertEquals(Integer.valueOf(3),
categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME));

View File

@ -0,0 +1,64 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
/**
 * XContent round-trip tests for {@link CustomWordEmbedding}.
 */
public class CustomWordEmbeddingTests extends AbstractXContentTestCase<CustomWordEmbedding> {

    @Override
    protected CustomWordEmbedding doParseInstance(XContentParser parser) throws IOException {
        return CustomWordEmbedding.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // The parser is lenient, so random extra fields must not break round-tripping.
        return true;
    }

    @Override
    protected CustomWordEmbedding createTestInstance() {
        return createRandom();
    }

    public static CustomWordEmbedding createRandom() {
        // Random rectangular quantization-scale matrix.
        int quantRows = randomIntBetween(1, 10);
        int quantCols = randomIntBetween(1, 10);
        short[][] quantScales = new short[quantRows][quantCols];
        for (int row = 0; row < quantRows; row++) {
            for (int col = 0; col < quantCols; col++) {
                quantScales[row][col] = randomShort();
            }
        }
        // Random rectangular embedding-weight matrix.
        int weightRows = randomIntBetween(1, 10);
        int weightCols = randomIntBetween(1, 10);
        byte[][] weights = new byte[weightRows][weightCols];
        for (int row = 0; row < weightRows; row++) {
            for (int col = 0; col < weightCols; col++) {
                weights[row][col] = randomByte();
            }
        }
        return new CustomWordEmbedding(quantScales, weights, randomAlphaOfLength(10), randomAlphaOfLength(10));
    }
}

View File

@ -0,0 +1,57 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.langident;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import java.io.IOException;
/**
 * XContent round-trip tests for {@link LangIdentNeuralNetwork}.
 */
public class LangIdentNeuralNetworkTests extends AbstractXContentTestCase<LangIdentNeuralNetwork> {

    @Override
    protected LangIdentNeuralNetwork doParseInstance(XContentParser parser) throws IOException {
        return LangIdentNeuralNetwork.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // The parser is lenient (ignoreUnknownFields = true), so extra fields must be tolerated.
        return true;
    }

    @Override
    protected LangIdentNeuralNetwork createTestInstance() {
        return createRandom();
    }

    public static LangIdentNeuralNetwork createRandom() {
        // The package-private constructor is reachable because this test lives in the same package.
        return new LangIdentNeuralNetwork(randomAlphaOfLength(10),
            LangNetLayerTests.createRandom(),
            LangNetLayerTests.createRandom());
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        // NOTE(review): this uses the provider from the server-side x-pack core package
        // (org.elasticsearch.xpack.core.ml.inference) rather than the client-side one in
        // org.elasticsearch.client.ml.inference — confirm this is intentional for an HLRC test.
        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    }
}

View File

@ -0,0 +1,55 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.langident;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.stream.Stream;
/**
 * XContent round-trip tests for {@link LangNetLayer}.
 */
public class LangNetLayerTests extends AbstractXContentTestCase<LangNetLayer> {

    @Override
    protected LangNetLayer doParseInstance(XContentParser parser) throws IOException {
        return LangNetLayer.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected LangNetLayer createTestInstance() {
        return createRandom();
    }

    public static LangNetLayer createRandom() {
        // Single-row layer with a random column count; the bias gets one entry per weight.
        int numWeights = randomIntBetween(1, 1000);
        double[] weights = new double[numWeights];
        for (int i = 0; i < numWeights; i++) {
            weights[i] = randomDouble();
        }
        double[] bias = new double[numWeights];
        for (int i = 0; i < numWeights; i++) {
            bias[i] = randomDouble();
        }
        return new LangNetLayer(weights, numWeights, 1, bias);
    }
}

View File

@ -198,6 +198,16 @@ public class TrainedModelIT extends ESRestTestCase {
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
}
public void testGetPrePackagedModels() throws IOException {
    // Fetch the shipped language-identification model, including its full definition, via the REST API.
    Response getModel = client().performRequest(new Request("GET",
        MachineLearning.BASE_PATH + "inference/lang_ident_model_1?human=true&include_model_definition=true"));
    assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
    String response = EntityUtils.toString(getModel.getEntity());
    // Sanity-check the payload: the model id and an inlined "definition" object must be present.
    assertThat(response, containsString("lang_ident_model_1"));
    assertThat(response, containsString("\"definition\""));
}
private static String buildRegressionModel(String modelId) throws IOException {
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
TrainedModelConfig.builder()