[ML][Inference] adjust so target_field always has inference result and optionally allow new top classes field in the classification config (#49923) (#49982)
This commit is contained in:
parent
62e128f02d
commit
049d854360
|
@ -12,6 +12,8 @@ import org.elasticsearch.common.io.stream.Writeable;
|
|||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -90,13 +92,15 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String resultField) {
|
||||
public void writeResult(IngestDocument document, String resultField, InferenceConfig config) {
|
||||
assert config instanceof ClassificationConfig;
|
||||
ClassificationConfig classificationConfig = (ClassificationConfig)config;
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
ExceptionsHelper.requireNonNull(resultField, "resultField");
|
||||
if (topClasses.isEmpty()) {
|
||||
document.setFieldValue(resultField, valueAsString());
|
||||
} else {
|
||||
document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
|
||||
document.setFieldValue(resultField, valueAsString());
|
||||
if (topClasses.isEmpty() == false) {
|
||||
document.setFieldValue(classificationConfig.getTopClassesResultsField(),
|
||||
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,10 +7,11 @@ package org.elasticsearch.xpack.core.ml.inference.results;
|
|||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
|
||||
|
||||
public interface InferenceResults extends NamedXContentObject, NamedWriteable {
|
||||
|
||||
void writeResult(IngestDocument document, String resultField);
|
||||
void writeResult(IngestDocument document, String resultField, InferenceConfig config);
|
||||
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
@ -49,7 +50,7 @@ public class RawInferenceResults extends SingleValueInferenceResults {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String resultField) {
|
||||
public void writeResult(IngestDocument document, String resultField, InferenceConfig config) {
|
||||
throw new UnsupportedOperationException("[raw] does not support writing inference results");
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -50,7 +51,7 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String resultField) {
|
||||
public void writeResult(IngestDocument document, String resultField, InferenceConfig config) {
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
ExceptionsHelper.requireNonNull(resultField, "resultField");
|
||||
document.setFieldValue(resultField, value());
|
||||
|
|
|
@ -21,37 +21,52 @@ public class ClassificationConfig implements InferenceConfig {
|
|||
|
||||
public static final String NAME = "classification";
|
||||
|
||||
public static final String DEFAULT_TOP_CLASSES_RESULT_FIELD = "top_classes";
|
||||
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||
public static final ParseField TOP_CLASSES_RESULT_FIELD = new ParseField("top_classes_result_field");
|
||||
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
|
||||
|
||||
public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0);
|
||||
public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0, DEFAULT_TOP_CLASSES_RESULT_FIELD);
|
||||
|
||||
private final int numTopClasses;
|
||||
private final String topClassesResultsField;
|
||||
|
||||
public static ClassificationConfig fromMap(Map<String, Object> map) {
|
||||
Map<String, Object> options = new HashMap<>(map);
|
||||
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
|
||||
String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULT_FIELD.getPreferredName());
|
||||
if (options.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
|
||||
}
|
||||
return new ClassificationConfig(numTopClasses);
|
||||
return new ClassificationConfig(numTopClasses, topClassesResultsField);
|
||||
}
|
||||
|
||||
public ClassificationConfig(Integer numTopClasses) {
|
||||
this(numTopClasses, null);
|
||||
}
|
||||
|
||||
public ClassificationConfig(Integer numTopClasses, String topClassesResultsField) {
|
||||
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
|
||||
this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULT_FIELD : topClassesResultsField;
|
||||
}
|
||||
|
||||
public ClassificationConfig(StreamInput in) throws IOException {
|
||||
this.numTopClasses = in.readInt();
|
||||
this.topClassesResultsField = in.readString();
|
||||
}
|
||||
|
||||
public int getNumTopClasses() {
|
||||
return numTopClasses;
|
||||
}
|
||||
|
||||
public String getTopClassesResultsField() {
|
||||
return topClassesResultsField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeInt(numTopClasses);
|
||||
out.writeString(topClassesResultsField);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -59,12 +74,12 @@ public class ClassificationConfig implements InferenceConfig {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ClassificationConfig that = (ClassificationConfig) o;
|
||||
return Objects.equals(numTopClasses, that.numTopClasses);
|
||||
return Objects.equals(numTopClasses, that.numTopClasses) && Objects.equals(topClassesResultsField, that.topClassesResultsField);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(numTopClasses);
|
||||
return Objects.hash(numTopClasses, topClassesResultsField);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -73,6 +88,7 @@ public class ClassificationConfig implements InferenceConfig {
|
|||
if (numTopClasses != 0) {
|
||||
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
||||
}
|
||||
builder.field(TOP_CLASSES_RESULT_FIELD.getPreferredName(), topClassesResultsField);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.results;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
@ -37,7 +38,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
public void testWriteResultsWithClassificationLabel() {
|
||||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, "foo", Collections.emptyList());
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
result.writeResult(document, "result_field", ClassificationConfig.EMPTY_PARAMS);
|
||||
|
||||
assertThat(document.getFieldValue("result_field", String.class), equalTo("foo"));
|
||||
}
|
||||
|
@ -45,7 +46,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
public void testWriteResultsWithoutClassificationLabel() {
|
||||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, Collections.emptyList());
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
result.writeResult(document, "result_field", ClassificationConfig.EMPTY_PARAMS);
|
||||
|
||||
assertThat(document.getFieldValue("result_field", String.class), equalTo("1.0"));
|
||||
}
|
||||
|
@ -60,15 +61,17 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
"foo",
|
||||
entries);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
result.writeResult(document, "result_field", new ClassificationConfig(3, "bar"));
|
||||
|
||||
List<?> list = document.getFieldValue("result_field", List.class);
|
||||
List<?> list = document.getFieldValue("bar", List.class);
|
||||
assertThat(list.size(), equalTo(3));
|
||||
|
||||
for(int i = 0; i < 3; i++) {
|
||||
Map<String, Object> map = (Map<String, Object>)list.get(i);
|
||||
assertThat(map, equalTo(entries.get(i).asValueMap()));
|
||||
}
|
||||
|
||||
assertThat(document.getFieldValue("result_field", String.class), equalTo("foo"));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.Writeable;
|
|||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
|
@ -24,7 +25,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
|
|||
public void testWriteResults() {
|
||||
RegressionInferenceResults result = new RegressionInferenceResults(0.3);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
result.writeResult(document, "result_field", new RegressionConfig());
|
||||
|
||||
assertThat(document.getFieldValue("result_field", Double.class), equalTo(0.3));
|
||||
}
|
||||
|
|
|
@ -10,22 +10,27 @@ import org.elasticsearch.common.io.stream.Writeable;
|
|||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class ClassificationConfigTests extends AbstractWireSerializingTestCase<ClassificationConfig> {
|
||||
|
||||
public static ClassificationConfig randomClassificationConfig() {
|
||||
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10));
|
||||
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10),
|
||||
randomBoolean() ? null : randomAlphaOfLength(10));
|
||||
}
|
||||
|
||||
public void testFromMap() {
|
||||
ClassificationConfig expected = new ClassificationConfig(0);
|
||||
ClassificationConfig expected = ClassificationConfig.EMPTY_PARAMS;
|
||||
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));
|
||||
|
||||
expected = new ClassificationConfig(3);
|
||||
assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)),
|
||||
equalTo(expected));
|
||||
expected = new ClassificationConfig(3, "foo");
|
||||
Map<String, Object> configMap = new HashMap<>();
|
||||
configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
|
||||
configMap.put(ClassificationConfig.TOP_CLASSES_RESULT_FIELD.getPreferredName(), "foo");
|
||||
assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected));
|
||||
}
|
||||
|
||||
public void testFromMapWithUnknownField() {
|
||||
|
|
|
@ -148,20 +148,8 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
|
|||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"result_class\",\n" +
|
||||
" \"inference_config\": {\"classification\":{}},\n" +
|
||||
" \"model_id\": \"test_classification\",\n" +
|
||||
" \"field_mappings\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"result_class_prob\",\n" +
|
||||
" \"inference_config\": {\"classification\": {\"num_top_classes\":2}},\n" +
|
||||
" \"inference_config\": {\"classification\": " +
|
||||
" {\"num_top_classes\":2, \"top_classes_result_field\": \"result_class_prob\"}},\n" +
|
||||
" \"model_id\": \"test_classification\",\n" +
|
||||
" \"field_mappings\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
|
|
|
@ -35,7 +35,6 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
|||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -154,7 +153,7 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
if (response.getInferenceResults().isEmpty()) {
|
||||
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
|
||||
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField, inferenceConfig);
|
||||
if (includeModelMetadata) {
|
||||
ingestDocument.setFieldValue(modelInfoField + "." + MODEL_ID, modelId);
|
||||
}
|
||||
|
@ -227,8 +226,7 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
}
|
||||
|
||||
@Override
|
||||
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config)
|
||||
throws Exception {
|
||||
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config) {
|
||||
|
||||
if (this.maxIngestProcessors <= currentInferenceProcessors) {
|
||||
throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " +
|
||||
|
@ -267,7 +265,7 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
this.maxIngestProcessors = maxIngestProcessors;
|
||||
}
|
||||
|
||||
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) throws IOException {
|
||||
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
|
||||
ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
|
||||
|
||||
if (inferenceConfig.size() != 1) {
|
||||
|
@ -284,7 +282,7 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
Map<String, Object> valueMap = (Map<String, Object>)value;
|
||||
|
||||
if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
|
||||
checkSupportedVersion(new ClassificationConfig(0));
|
||||
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
|
||||
return ClassificationConfig.fromMap(valueMap);
|
||||
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
|
||||
checkSupportedVersion(new RegressionConfig());
|
||||
|
|
|
@ -51,7 +51,7 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
"my_processor",
|
||||
targetField,
|
||||
"classification_model",
|
||||
new ClassificationConfig(0),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
Collections.emptyMap(),
|
||||
"ml.my_processor",
|
||||
true);
|
||||
|
@ -78,7 +78,7 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
"my_processor",
|
||||
targetField,
|
||||
"classification_model",
|
||||
new ClassificationConfig(2),
|
||||
new ClassificationConfig(2, null),
|
||||
Collections.emptyMap(),
|
||||
"ml.my_processor",
|
||||
true);
|
||||
|
@ -96,10 +96,44 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
true);
|
||||
inferenceProcessor.mutateDocument(response, document);
|
||||
|
||||
assertThat((List<Map<?,?>>)document.getFieldValue(targetField, List.class),
|
||||
assertThat((List<Map<?,?>>)document.getFieldValue(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULT_FIELD, List.class),
|
||||
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
|
||||
assertThat(document.getFieldValue("ml", Map.class),
|
||||
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
|
||||
assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testMutateDocumentClassificationTopNClassesWithSpecificField() {
|
||||
String targetField = "classification_value_probabilities";
|
||||
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
|
||||
auditor,
|
||||
"my_processor",
|
||||
targetField,
|
||||
"classification_model",
|
||||
new ClassificationConfig(2, "my_top_classes"),
|
||||
Collections.emptyMap(),
|
||||
"ml.my_processor",
|
||||
true);
|
||||
|
||||
Map<String, Object> source = new HashMap<>();
|
||||
Map<String, Object> ingestMetadata = new HashMap<>();
|
||||
IngestDocument document = new IngestDocument(source, ingestMetadata);
|
||||
|
||||
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
|
||||
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
|
||||
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
|
||||
|
||||
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
|
||||
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes)),
|
||||
true);
|
||||
inferenceProcessor.mutateDocument(response, document);
|
||||
|
||||
assertThat((List<Map<?,?>>)document.getFieldValue("my_top_classes", List.class),
|
||||
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
|
||||
assertThat(document.getFieldValue("ml", Map.class),
|
||||
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
|
||||
assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
|
||||
}
|
||||
|
||||
public void testMutateDocumentRegression() {
|
||||
|
@ -194,7 +228,7 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
"my_processor",
|
||||
"my_field",
|
||||
modelId,
|
||||
new ClassificationConfig(topNClasses),
|
||||
new ClassificationConfig(topNClasses, null),
|
||||
Collections.emptyMap(),
|
||||
"ml.my_processor",
|
||||
false);
|
||||
|
@ -225,7 +259,7 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
"my_processor",
|
||||
"my_field",
|
||||
modelId,
|
||||
new ClassificationConfig(topNClasses),
|
||||
new ClassificationConfig(topNClasses, null),
|
||||
fieldMapping,
|
||||
"ml.my_processor",
|
||||
false);
|
||||
|
|
|
@ -131,7 +131,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
|
||||
|
||||
// Test classification
|
||||
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(0), true);
|
||||
request = new InternalInferModelAction.Request(modelId2, toInfer, ClassificationConfig.EMPTY_PARAMS, true);
|
||||
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
|
||||
assertThat(response.getInferenceResults()
|
||||
.stream()
|
||||
|
@ -140,7 +140,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
contains("not_to_be", "to_be"));
|
||||
|
||||
// Get top classes
|
||||
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2), true);
|
||||
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2, null), true);
|
||||
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
|
||||
|
||||
ClassificationInferenceResults classificationInferenceResults =
|
||||
|
@ -159,7 +159,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
|
||||
|
||||
// Test that top classes restrict the number returned
|
||||
request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1), true);
|
||||
request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1, null), true);
|
||||
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
|
||||
|
||||
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
|
||||
|
|
Loading…
Reference in New Issue