parent
2a4fd8329b
commit
66b3e89482
|
@ -6,8 +6,11 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
||||||
|
|
||||||
|
import org.apache.logging.log4j.LogManager;
|
||||||
|
import org.apache.logging.log4j.Logger;
|
||||||
import org.apache.lucene.util.RamUsageEstimator;
|
import org.apache.lucene.util.RamUsageEstimator;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.collect.Tuple;
|
import org.elasticsearch.common.collect.Tuple;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
@ -25,6 +28,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Leniently
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.LinkedHashSet;
|
import java.util.LinkedHashSet;
|
||||||
|
@ -49,6 +53,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.En
|
||||||
public class EnsembleInferenceModel implements InferenceModel {
|
public class EnsembleInferenceModel implements InferenceModel {
|
||||||
|
|
||||||
public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
|
public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
|
||||||
|
private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
|
private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
|
||||||
|
@ -136,6 +141,8 @@ public class EnsembleInferenceModel implements InferenceModel {
|
||||||
if (preparedForInference == false) {
|
if (preparedForInference == false) {
|
||||||
throw ExceptionsHelper.serverError("model is not prepared for inference");
|
throw ExceptionsHelper.serverError("model is not prepared for inference");
|
||||||
}
|
}
|
||||||
|
LOGGER.debug("Inference called with feature names [{}]",
|
||||||
|
featureNames == null ? "<null>" : Strings.arrayToCommaDelimitedString(featureNames));
|
||||||
assert featureNames != null && featureNames.length > 0;
|
assert featureNames != null && featureNames.length > 0;
|
||||||
double[][] inferenceResults = new double[this.models.size()][];
|
double[][] inferenceResults = new double[this.models.size()][];
|
||||||
double[][] featureInfluence = new double[features.length][];
|
double[][] featureInfluence = new double[features.length][];
|
||||||
|
@ -237,12 +244,14 @@ public class EnsembleInferenceModel implements InferenceModel {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
|
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
|
||||||
|
LOGGER.debug("rewriting features {}", newFeatureIndexMapping);
|
||||||
if (preparedForInference) {
|
if (preparedForInference) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
preparedForInference = true;
|
preparedForInference = true;
|
||||||
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
|
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
|
||||||
Set<String> referencedFeatures = subModelFeatures();
|
Set<String> referencedFeatures = subModelFeatures();
|
||||||
|
LOGGER.debug("detected submodel feature names {}", referencedFeatures);
|
||||||
int newFeatureIndex = 0;
|
int newFeatureIndex = 0;
|
||||||
newFeatureIndexMapping = new HashMap<>();
|
newFeatureIndexMapping = new HashMap<>();
|
||||||
this.featureNames = new String[referencedFeatures.size()];
|
this.featureNames = new String[referencedFeatures.size()];
|
||||||
|
@ -301,4 +310,16 @@ public class EnsembleInferenceModel implements InferenceModel {
|
||||||
return classificationWeights;
|
return classificationWeights;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "EnsembleInferenceModel{" +
|
||||||
|
"featureNames=" + Arrays.toString(featureNames) +
|
||||||
|
", models=" + models +
|
||||||
|
", outputAggregator=" + outputAggregator +
|
||||||
|
", targetType=" + targetType +
|
||||||
|
", classificationLabels=" + classificationLabels +
|
||||||
|
", classificationWeights=" + Arrays.toString(classificationWeights) +
|
||||||
|
", preparedForInference=" + preparedForInference +
|
||||||
|
'}';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,8 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
||||||
|
|
||||||
|
import org.apache.logging.log4j.LogManager;
|
||||||
|
import org.apache.logging.log4j.Logger;
|
||||||
import org.apache.lucene.util.Accountable;
|
import org.apache.lucene.util.Accountable;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.Numbers;
|
import org.elasticsearch.common.Numbers;
|
||||||
|
@ -28,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
|
||||||
import org.elasticsearch.xpack.core.ml.job.config.Operator;
|
import org.elasticsearch.xpack.core.ml.job.config.Operator;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -56,6 +59,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo
|
||||||
|
|
||||||
public class TreeInferenceModel implements InferenceModel {
|
public class TreeInferenceModel implements InferenceModel {
|
||||||
|
|
||||||
|
private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class);
|
||||||
public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);
|
public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
|
@ -304,6 +308,7 @@ public class TreeInferenceModel implements InferenceModel {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
|
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
|
||||||
|
LOGGER.debug("rewriting features {}", newFeatureIndexMapping);
|
||||||
if (preparedForInference) {
|
if (preparedForInference) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -358,6 +363,20 @@ public class TreeInferenceModel implements InferenceModel {
|
||||||
return nodes;
|
return nodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "TreeInferenceModel{" +
|
||||||
|
"nodes=" + Arrays.toString(nodes) +
|
||||||
|
", featureNames=" + Arrays.toString(featureNames) +
|
||||||
|
", targetType=" + targetType +
|
||||||
|
", classificationLabels=" + classificationLabels +
|
||||||
|
", highOrderCategory=" + highOrderCategory +
|
||||||
|
", maxDepth=" + maxDepth +
|
||||||
|
", leafSize=" + leafSize +
|
||||||
|
", preparedForInference=" + preparedForInference +
|
||||||
|
'}';
|
||||||
|
}
|
||||||
|
|
||||||
private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
|
private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
|
||||||
Node node = nodes[nodeIndex];
|
Node node = nodes[nodeIndex];
|
||||||
if (node instanceof LeafNode) {
|
if (node instanceof LeafNode) {
|
||||||
|
@ -519,6 +538,19 @@ public class TreeInferenceModel implements InferenceModel {
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
return SHALLOW_SIZE;
|
return SHALLOW_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "InnerNode{" +
|
||||||
|
"operator=" + operator +
|
||||||
|
", threshold=" + threshold +
|
||||||
|
", splitFeature=" + splitFeature +
|
||||||
|
", defaultLeft=" + defaultLeft +
|
||||||
|
", leftChild=" + leftChild +
|
||||||
|
", rightChild=" + rightChild +
|
||||||
|
", numberSamples=" + numberSamples +
|
||||||
|
'}';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class LeafNode extends Node {
|
public static class LeafNode extends Node {
|
||||||
|
@ -544,5 +576,13 @@ public class TreeInferenceModel implements InferenceModel {
|
||||||
public double[] getLeafValue() {
|
public double[] getLeafValue() {
|
||||||
return leafValue;
|
return leafValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "LeafNode{" +
|
||||||
|
"leafValue=" + Arrays.toString(leafValue) +
|
||||||
|
", numberSamples=" + numberSamples +
|
||||||
|
'}';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import org.elasticsearch.action.index.IndexAction;
|
||||||
import org.elasticsearch.action.index.IndexRequest;
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
import org.elasticsearch.action.search.SearchResponse;
|
import org.elasticsearch.action.search.SearchResponse;
|
||||||
import org.elasticsearch.action.support.WriteRequest;
|
import org.elasticsearch.action.support.WriteRequest;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.unit.TimeValue;
|
import org.elasticsearch.common.unit.TimeValue;
|
||||||
import org.elasticsearch.common.xcontent.XContentType;
|
import org.elasticsearch.common.xcontent.XContentType;
|
||||||
import org.elasticsearch.index.query.QueryBuilder;
|
import org.elasticsearch.index.query.QueryBuilder;
|
||||||
|
@ -37,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Multi
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -85,9 +87,25 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
private String destIndex;
|
private String destIndex;
|
||||||
private boolean analysisUsesExistingDestIndex;
|
private boolean analysisUsesExistingDestIndex;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setupLogging() {
|
||||||
|
client().admin().cluster()
|
||||||
|
.prepareUpdateSettings()
|
||||||
|
.setTransientSettings(Settings.builder()
|
||||||
|
.put("logger.org.elasticsearch.xpack.ml.dataframe.inference", "DEBUG")
|
||||||
|
.put("logger.org.elasticsearch.xpack.core.ml.inference", "DEBUG"))
|
||||||
|
.get();
|
||||||
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void cleanup() {
|
public void cleanup() {
|
||||||
cleanUp();
|
cleanUp();
|
||||||
|
client().admin().cluster()
|
||||||
|
.prepareUpdateSettings()
|
||||||
|
.setTransientSettings(Settings.builder()
|
||||||
|
.putNull("logger.org.elasticsearch.xpack.ml.dataframe.inference")
|
||||||
|
.putNull("logger.org.elasticsearch.xpack.core.ml.inference"))
|
||||||
|
.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||||
|
|
|
@ -85,6 +85,7 @@ public class InferenceRunner {
|
||||||
TestDocsIterator testDocsIterator = new TestDocsIterator(new OriginSettingClient(client, ClientHelper.ML_ORIGIN), config,
|
TestDocsIterator testDocsIterator = new TestDocsIterator(new OriginSettingClient(client, ClientHelper.ML_ORIGIN), config,
|
||||||
extractedFields);
|
extractedFields);
|
||||||
try (LocalModel localModel = localModelPlainActionFuture.actionGet()) {
|
try (LocalModel localModel = localModelPlainActionFuture.actionGet()) {
|
||||||
|
LOGGER.debug("Loaded inference model [{}]", localModel);
|
||||||
inferTestDocs(localModel, testDocsIterator);
|
inferTestDocs(localModel, testDocsIterator);
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
|
|
@ -228,4 +228,22 @@ public class LocalModel implements Closeable {
|
||||||
public void close() {
|
public void close() {
|
||||||
release();
|
release();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "LocalModel{" +
|
||||||
|
"trainedModelDefinition=" + trainedModelDefinition +
|
||||||
|
", modelId='" + modelId + '\'' +
|
||||||
|
", fieldNames=" + fieldNames +
|
||||||
|
", defaultFieldMap=" + defaultFieldMap +
|
||||||
|
", statsAccumulator=" + statsAccumulator +
|
||||||
|
", trainedModelStatsService=" + trainedModelStatsService +
|
||||||
|
", persistenceQuotient=" + persistenceQuotient +
|
||||||
|
", currentInferenceCount=" + currentInferenceCount +
|
||||||
|
", inferenceConfig=" + inferenceConfig +
|
||||||
|
", licenseLevel=" + licenseLevel +
|
||||||
|
", trainedModelCircuitBreaker=" + trainedModelCircuitBreaker +
|
||||||
|
", referenceCount=" + referenceCount +
|
||||||
|
'}';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue