[ML] Add parsers for inference configuration classes (#51300)

This commit is contained in:
David Kyle 2020-01-22 17:03:01 +00:00 committed by GitHub
parent 4590d4156a
commit 0ac03ac5e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 84 additions and 25 deletions

View File

@ -526,8 +526,10 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
RegressionInferenceResults.NAME,
RegressionInferenceResults::new),
// ML - Inference Configuration
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new),
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new),
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME.getPreferredName(),
ClassificationConfig::new),
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME.getPreferredName(),
RegressionConfig::new),
// monitoring
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),
@ -591,7 +593,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
new NamedWriteableRegistry.Entry(LifecycleAction.class, SetPriorityAction.NAME, SetPriorityAction::new),
new NamedWriteableRegistry.Entry(LifecycleAction.class, UnfollowAction.NAME, UnfollowAction::new),
new NamedWriteableRegistry.Entry(LifecycleAction.class, WaitForSnapshotAction.NAME, WaitForSnapshotAction::new),
// Data Frame
// Transforms
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.TRANSFORM, TransformFeatureSetUsage::new),
new NamedWriteableRegistry.Entry(PersistentTaskParams.class, TransformField.TASK_NAME, TransformTaskParams::new),
new NamedWriteableRegistry.Entry(Task.Status.class, TransformField.TASK_NAME, TransformState::new),
@ -647,7 +649,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
RollupJobStatus::fromXContent),
new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(RollupJobStatus.NAME),
RollupJobStatus::fromXContent),
// Data Frame
// Transforms
new NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(TransformField.TASK_NAME),
TransformTaskParams::fromXContent),
new NamedXContentRegistry.Entry(Task.Status.class, new ParseField(TransformField.TASK_NAME),

View File

@ -99,6 +99,12 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
LogisticRegression.NAME,
LogisticRegression::fromXContentStrict));
// Inference Configs
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME,
ClassificationConfig::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME,
RegressionConfig::fromXContent));
return namedXContent;
}
@ -142,8 +148,10 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
RegressionInferenceResults::new));
// Inference Configs
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
RegressionConfig.NAME.getPreferredName(), RegressionConfig::new));
return namedWriteables;
}

View File

@ -9,7 +9,9 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -17,12 +19,15 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class ClassificationConfig implements InferenceConfig {
public static final String NAME = "classification";
public static final ParseField NAME = new ParseField("classification");
public static final String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes";
private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
@ -45,6 +50,20 @@ public class ClassificationConfig implements InferenceConfig {
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField);
}
private static final ConstructingObjectParser<ClassificationConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig(
(Integer) args[0], (String) args[1], (String) args[2]));
static {
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD);
}
public static ClassificationConfig fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public ClassificationConfig(Integer numTopClasses) {
this(numTopClasses, null, null);
}
@ -109,12 +128,12 @@ public class ClassificationConfig implements InferenceConfig {
@Override
public String getWriteableName() {
return NAME;
return NAME.getPreferredName();
}
@Override
public String getName() {
return NAME;
return NAME.getPreferredName();
}
@Override

View File

@ -9,7 +9,9 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -17,9 +19,11 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class RegressionConfig implements InferenceConfig {
public static final String NAME = "regression";
public static final ParseField NAME = new ParseField("regression");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
@ -35,6 +39,17 @@ public class RegressionConfig implements InferenceConfig {
return new RegressionConfig(resultsField);
}
private static final ConstructingObjectParser<RegressionConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0]));
static {
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
}
public static RegressionConfig fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final String resultsField;
public RegressionConfig(String resultsField) {
@ -51,7 +66,7 @@ public class RegressionConfig implements InferenceConfig {
@Override
public String getWriteableName() {
return NAME;
return NAME.getPreferredName();
}
@Override
@ -61,7 +76,7 @@ public class RegressionConfig implements InferenceConfig {
@Override
public String getName() {
return NAME;
return NAME.getPreferredName();
}
@Override

View File

@ -7,15 +7,17 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
public class ClassificationConfigTests extends AbstractWireSerializingTestCase<ClassificationConfig> {
public class ClassificationConfigTests extends AbstractSerializingTestCase<ClassificationConfig> {
public static ClassificationConfig randomClassificationConfig() {
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10),
@ -52,4 +54,8 @@ public class ClassificationConfigTests extends AbstractWireSerializingTestCase<C
return ClassificationConfig::new;
}
@Override
protected ClassificationConfig doParseInstance(XContentParser parser) throws IOException {
return ClassificationConfig.fromXContent(parser);
}
}

View File

@ -7,15 +7,17 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
public class RegressionConfigTests extends AbstractWireSerializingTestCase<RegressionConfig> {
public class RegressionConfigTests extends AbstractSerializingTestCase<RegressionConfig> {
public static RegressionConfig randomRegressionConfig() {
return new RegressionConfig(randomBoolean() ? null : randomAlphaOfLength(10));
@ -45,4 +47,8 @@ public class RegressionConfigTests extends AbstractWireSerializingTestCase<Regre
return RegressionConfig::new;
}
@Override
protected RegressionConfig doParseInstance(XContentParser parser) throws IOException {
return RegressionConfig.fromXContent(parser);
}
}

View File

@ -275,12 +275,12 @@ public class InferenceProcessor extends AbstractProcessor {
@SuppressWarnings("unchecked")
Map<String, Object> valueMap = (Map<String, Object>)value;
if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
if (inferenceConfig.containsKey(ClassificationConfig.NAME.getPreferredName())) {
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
ClassificationConfig config = ClassificationConfig.fromMap(valueMap);
checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
return config;
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
} else if (inferenceConfig.containsKey(RegressionConfig.NAME.getPreferredName())) {
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
RegressionConfig config = RegressionConfig.fromMap(valueMap);
checkFieldUniqueness(config.getResultsField());
@ -288,7 +288,7 @@ public class InferenceProcessor extends AbstractProcessor {
} else {
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
inferenceConfig.keySet(),
Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME));
Arrays.asList(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName()));
}
}

View File

@ -178,7 +178,8 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
put(InferenceProcessor.INFERENCE_CONFIG,
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
}};
try {
@ -195,7 +196,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME,
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME.getPreferredName(),
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
}};
@ -220,7 +221,8 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
put(InferenceProcessor.INFERENCE_CONFIG,
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
}};
try {
@ -233,7 +235,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME,
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME.getPreferredName(),
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
}};
@ -254,7 +256,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "ml");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME,
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME.getPreferredName(),
Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")));
}};
@ -302,7 +304,8 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
Collections.singletonMap(InferenceProcessor.TYPE,
new HashMap<String, Object>() {{
put(InferenceProcessor.MODEL_ID, modelId);
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
put(InferenceProcessor.INFERENCE_CONFIG,
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
put(InferenceProcessor.TARGET_FIELD, "new_field");
put(InferenceProcessor.FIELD_MAPPINGS, Collections.singletonMap("source", "dest"));
}}))))) {