[ML] Add parsers for inference configuration classes (#51300)
This commit is contained in:
parent
4590d4156a
commit
0ac03ac5e7
|
@ -526,8 +526,10 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
|||
RegressionInferenceResults.NAME,
|
||||
RegressionInferenceResults::new),
|
||||
// ML - Inference Configuration
|
||||
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new),
|
||||
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new),
|
||||
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME.getPreferredName(),
|
||||
ClassificationConfig::new),
|
||||
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME.getPreferredName(),
|
||||
RegressionConfig::new),
|
||||
|
||||
// monitoring
|
||||
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),
|
||||
|
@ -591,7 +593,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
|||
new NamedWriteableRegistry.Entry(LifecycleAction.class, SetPriorityAction.NAME, SetPriorityAction::new),
|
||||
new NamedWriteableRegistry.Entry(LifecycleAction.class, UnfollowAction.NAME, UnfollowAction::new),
|
||||
new NamedWriteableRegistry.Entry(LifecycleAction.class, WaitForSnapshotAction.NAME, WaitForSnapshotAction::new),
|
||||
// Data Frame
|
||||
// Transforms
|
||||
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.TRANSFORM, TransformFeatureSetUsage::new),
|
||||
new NamedWriteableRegistry.Entry(PersistentTaskParams.class, TransformField.TASK_NAME, TransformTaskParams::new),
|
||||
new NamedWriteableRegistry.Entry(Task.Status.class, TransformField.TASK_NAME, TransformState::new),
|
||||
|
@ -647,7 +649,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
|||
RollupJobStatus::fromXContent),
|
||||
new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(RollupJobStatus.NAME),
|
||||
RollupJobStatus::fromXContent),
|
||||
// Data Frame
|
||||
// Transforms
|
||||
new NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(TransformField.TASK_NAME),
|
||||
TransformTaskParams::fromXContent),
|
||||
new NamedXContentRegistry.Entry(Task.Status.class, new ParseField(TransformField.TASK_NAME),
|
||||
|
|
|
@ -99,6 +99,12 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
|
|||
LogisticRegression.NAME,
|
||||
LogisticRegression::fromXContentStrict));
|
||||
|
||||
// Inference Configs
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME,
|
||||
ClassificationConfig::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME,
|
||||
RegressionConfig::fromXContent));
|
||||
|
||||
return namedXContent;
|
||||
}
|
||||
|
||||
|
@ -142,8 +148,10 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
|
|||
RegressionInferenceResults::new));
|
||||
|
||||
// Inference Configs
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
|
||||
ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
|
||||
RegressionConfig.NAME.getPreferredName(), RegressionConfig::new));
|
||||
|
||||
return namedWriteables;
|
||||
}
|
||||
|
|
|
@ -9,7 +9,9 @@ import org.elasticsearch.Version;
|
|||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -17,12 +19,15 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class ClassificationConfig implements InferenceConfig {
|
||||
|
||||
public static final String NAME = "classification";
|
||||
public static final ParseField NAME = new ParseField("classification");
|
||||
|
||||
public static final String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes";
|
||||
private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
|
||||
|
||||
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
|
||||
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||
public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
|
||||
|
@ -45,6 +50,20 @@ public class ClassificationConfig implements InferenceConfig {
|
|||
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField);
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<ClassificationConfig, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig(
|
||||
(Integer) args[0], (String) args[1], (String) args[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
|
||||
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
|
||||
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD);
|
||||
}
|
||||
|
||||
public static ClassificationConfig fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public ClassificationConfig(Integer numTopClasses) {
|
||||
this(numTopClasses, null, null);
|
||||
}
|
||||
|
@ -109,12 +128,12 @@ public class ClassificationConfig implements InferenceConfig {
|
|||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -9,7 +9,9 @@ import org.elasticsearch.Version;
|
|||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -17,9 +19,11 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class RegressionConfig implements InferenceConfig {
|
||||
|
||||
public static final String NAME = "regression";
|
||||
public static final ParseField NAME = new ParseField("regression");
|
||||
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
|
||||
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
|
||||
private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
|
||||
|
@ -35,6 +39,17 @@ public class RegressionConfig implements InferenceConfig {
|
|||
return new RegressionConfig(resultsField);
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<RegressionConfig, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
|
||||
}
|
||||
|
||||
public static RegressionConfig fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final String resultsField;
|
||||
|
||||
public RegressionConfig(String resultsField) {
|
||||
|
@ -51,7 +66,7 @@ public class RegressionConfig implements InferenceConfig {
|
|||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -61,7 +76,7 @@ public class RegressionConfig implements InferenceConfig {
|
|||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -7,15 +7,17 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
|||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class ClassificationConfigTests extends AbstractWireSerializingTestCase<ClassificationConfig> {
|
||||
public class ClassificationConfigTests extends AbstractSerializingTestCase<ClassificationConfig> {
|
||||
|
||||
public static ClassificationConfig randomClassificationConfig() {
|
||||
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10),
|
||||
|
@ -52,4 +54,8 @@ public class ClassificationConfigTests extends AbstractWireSerializingTestCase<C
|
|||
return ClassificationConfig::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClassificationConfig doParseInstance(XContentParser parser) throws IOException {
|
||||
return ClassificationConfig.fromXContent(parser);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,15 +7,17 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
|||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class RegressionConfigTests extends AbstractWireSerializingTestCase<RegressionConfig> {
|
||||
public class RegressionConfigTests extends AbstractSerializingTestCase<RegressionConfig> {
|
||||
|
||||
public static RegressionConfig randomRegressionConfig() {
|
||||
return new RegressionConfig(randomBoolean() ? null : randomAlphaOfLength(10));
|
||||
|
@ -45,4 +47,8 @@ public class RegressionConfigTests extends AbstractWireSerializingTestCase<Regre
|
|||
return RegressionConfig::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RegressionConfig doParseInstance(XContentParser parser) throws IOException {
|
||||
return RegressionConfig.fromXContent(parser);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -275,12 +275,12 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> valueMap = (Map<String, Object>)value;
|
||||
|
||||
if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
|
||||
if (inferenceConfig.containsKey(ClassificationConfig.NAME.getPreferredName())) {
|
||||
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
|
||||
ClassificationConfig config = ClassificationConfig.fromMap(valueMap);
|
||||
checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
|
||||
return config;
|
||||
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
|
||||
} else if (inferenceConfig.containsKey(RegressionConfig.NAME.getPreferredName())) {
|
||||
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
|
||||
RegressionConfig config = RegressionConfig.fromMap(valueMap);
|
||||
checkFieldUniqueness(config.getResultsField());
|
||||
|
@ -288,7 +288,7 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
} else {
|
||||
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
|
||||
inferenceConfig.keySet(),
|
||||
Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME));
|
||||
Arrays.asList(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -178,7 +178,8 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
|
||||
put(InferenceProcessor.INFERENCE_CONFIG,
|
||||
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
|
||||
}};
|
||||
|
||||
try {
|
||||
|
@ -195,7 +196,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME,
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME.getPreferredName(),
|
||||
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
|
||||
}};
|
||||
|
||||
|
@ -220,7 +221,8 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
|
||||
put(InferenceProcessor.INFERENCE_CONFIG,
|
||||
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
|
||||
}};
|
||||
|
||||
try {
|
||||
|
@ -233,7 +235,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME,
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME.getPreferredName(),
|
||||
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
|
||||
}};
|
||||
|
||||
|
@ -254,7 +256,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "ml");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME,
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME.getPreferredName(),
|
||||
Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")));
|
||||
}};
|
||||
|
||||
|
@ -302,7 +304,8 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
Collections.singletonMap(InferenceProcessor.TYPE,
|
||||
new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.MODEL_ID, modelId);
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
|
||||
put(InferenceProcessor.INFERENCE_CONFIG,
|
||||
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
|
||||
put(InferenceProcessor.TARGET_FIELD, "new_field");
|
||||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.singletonMap("source", "dest"));
|
||||
}}))))) {
|
||||
|
|
Loading…
Reference in New Issue