parent 2a4fd8329b
commit 66b3e89482
EnsembleInferenceModel.java

@@ -6,8 +6,11 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;

@@ -25,6 +28,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Leniently
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;

@@ -49,6 +53,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.En
public class EnsembleInferenceModel implements InferenceModel {

    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
    private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(

@@ -136,6 +141,8 @@ public class EnsembleInferenceModel implements InferenceModel {
        if (preparedForInference == false) {
            throw ExceptionsHelper.serverError("model is not prepared for inference");
        }
        LOGGER.debug("Inference called with feature names [{}]",
            featureNames == null ? "<null>" : Strings.arrayToCommaDelimitedString(featureNames));
        assert featureNames != null && featureNames.length > 0;
        double[][] inferenceResults = new double[this.models.size()][];
        double[][] featureInfluence = new double[features.length][];

@@ -237,12 +244,14 @@

    @Override
    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
        LOGGER.debug("rewriting features {}", newFeatureIndexMapping);
        if (preparedForInference) {
            return;
        }
        preparedForInference = true;
        if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
            Set<String> referencedFeatures = subModelFeatures();
            LOGGER.debug("detected submodel feature names {}", referencedFeatures);
            int newFeatureIndex = 0;
            newFeatureIndexMapping = new HashMap<>();
            this.featureNames = new String[referencedFeatures.size()];

@@ -301,4 +310,16 @@ public class EnsembleInferenceModel implements InferenceModel {
        return classificationWeights;
    }

    @Override
    public String toString() {
        return "EnsembleInferenceModel{" +
            "featureNames=" + Arrays.toString(featureNames) +
            ", models=" + models +
            ", outputAggregator=" + outputAggregator +
            ", targetType=" + targetType +
            ", classificationLabels=" + classificationLabels +
            ", classificationWeights=" + Arrays.toString(classificationWeights) +
            ", preparedForInference=" + preparedForInference +
            '}';
    }
}
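Note: the guard added to rewriteFeatureIndices makes preparation idempotent; the first call flips preparedForInference and rewrites the feature indices, and every later call returns immediately. A minimal standalone sketch of that pattern (the class and helper names here are illustrative, not part of this commit):

import java.util.Map;

// Illustrative sketch of the idempotent-preparation guard used above:
// only the first call performs the rewrite, later calls are no-ops.
class OnceRewritable {
    private volatile boolean preparedForInference = false;

    void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
        if (preparedForInference) {
            return; // already prepared; a second rewrite would corrupt the indices
        }
        preparedForInference = true;
        // ... perform the actual index rewrite against newFeatureIndexMapping ...
    }
}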
TreeInferenceModel.java

@@ -6,6 +6,8 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Numbers;

@@ -28,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

@@ -56,6 +59,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo

public class TreeInferenceModel implements InferenceModel {

    private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class);
    public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);

    @SuppressWarnings("unchecked")

@@ -304,6 +308,7 @@

    @Override
    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
        LOGGER.debug("rewriting features {}", newFeatureIndexMapping);
        if (preparedForInference) {
            return;
        }

@@ -358,6 +363,20 @@
        return nodes;
    }

    @Override
    public String toString() {
        return "TreeInferenceModel{" +
            "nodes=" + Arrays.toString(nodes) +
            ", featureNames=" + Arrays.toString(featureNames) +
            ", targetType=" + targetType +
            ", classificationLabels=" + classificationLabels +
            ", highOrderCategory=" + highOrderCategory +
            ", maxDepth=" + maxDepth +
            ", leafSize=" + leafSize +
            ", preparedForInference=" + preparedForInference +
            '}';
    }

    private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
        Node node = nodes[nodeIndex];
        if (node instanceof LeafNode) {

@@ -519,6 +538,19 @@
        public long ramBytesUsed() {
            return SHALLOW_SIZE;
        }

        @Override
        public String toString() {
            return "InnerNode{" +
                "operator=" + operator +
                ", threshold=" + threshold +
                ", splitFeature=" + splitFeature +
                ", defaultLeft=" + defaultLeft +
                ", leftChild=" + leftChild +
                ", rightChild=" + rightChild +
                ", numberSamples=" + numberSamples +
                '}';
        }
    }

    public static class LeafNode extends Node {

@@ -544,5 +576,13 @@
        public double[] getLeafValue() {
            return leafValue;
        }

        @Override
        public String toString() {
            return "LeafNode{" +
                "leafValue=" + Arrays.toString(leafValue) +
                ", numberSamples=" + numberSamples +
                '}';
        }
    }
}
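The getDepth helper visible in context walks the tree recursively over a flat node array, where inner nodes reference their children by index. A self-contained sketch of the same idea, using simplified stand-ins for the production Node, InnerNode and LeafNode types (not the real classes):

// Simplified sketch: a tree stored as a flat array; depth is the longest
// root-to-leaf path, found by recursing into both children of each inner node.
final class TreeDepthSketch {
    static abstract class Node {}
    static final class LeafNode extends Node {}
    static final class InnerNode extends Node {
        final int leftChild;
        final int rightChild;
        InnerNode(int leftChild, int rightChild) {
            this.leftChild = leftChild;
            this.rightChild = rightChild;
        }
    }

    static int getDepth(Node[] nodes, int nodeIndex, int depth) {
        Node node = nodes[nodeIndex];
        if (node instanceof LeafNode) {
            return depth; // a leaf terminates this branch
        }
        InnerNode innerNode = (InnerNode) node;
        return Math.max(
            getDepth(nodes, innerNode.leftChild, depth + 1),
            getDepth(nodes, innerNode.rightChild, depth + 1));
    }
}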
ClassificationIT.java

@@ -18,6 +18,7 @@ import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryBuilder;

@@ -37,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Multi
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;

@@ -85,9 +87,25 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
    private String destIndex;
    private boolean analysisUsesExistingDestIndex;

    @Before
    public void setupLogging() {
        client().admin().cluster()
            .prepareUpdateSettings()
            .setTransientSettings(Settings.builder()
                .put("logger.org.elasticsearch.xpack.ml.dataframe.inference", "DEBUG")
                .put("logger.org.elasticsearch.xpack.core.ml.inference", "DEBUG"))
            .get();
    }

    @After
    public void cleanup() {
        cleanUp();
        client().admin().cluster()
            .prepareUpdateSettings()
            .setTransientSettings(Settings.builder()
                .putNull("logger.org.elasticsearch.xpack.ml.dataframe.inference")
                .putNull("logger.org.elasticsearch.xpack.core.ml.inference"))
            .get();
    }

    public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
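The @Before/@After pair raises the two ML inference loggers to DEBUG for every test, then clears the override afterwards: putNull removes the transient setting, so each logger falls back to its configured default. Transient (rather than persistent) settings are a reasonable fit here since they also vanish on a full cluster restart. The same toggle pattern as a generic JUnit 4 sketch; the helper below is hypothetical, where the IT uses the client().admin().cluster().prepareUpdateSettings() chain shown above:

import org.junit.After;
import org.junit.Before;

// Generic sketch of the per-test logger toggle used in ClassificationIT.
public abstract class DebugLoggingTestBase {

    // Hypothetical helper: a null level means "remove the override".
    protected abstract void updateLoggerSetting(String settingKey, String level);

    @Before
    public void setupLogging() {
        updateLoggerSetting("logger.org.elasticsearch.xpack.ml.dataframe.inference", "DEBUG");
        updateLoggerSetting("logger.org.elasticsearch.xpack.core.ml.inference", "DEBUG");
    }

    @After
    public void resetLogging() {
        updateLoggerSetting("logger.org.elasticsearch.xpack.ml.dataframe.inference", null);
        updateLoggerSetting("logger.org.elasticsearch.xpack.core.ml.inference", null);
    }
}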
InferenceRunner.java

@@ -85,6 +85,7 @@ public class InferenceRunner {
            TestDocsIterator testDocsIterator = new TestDocsIterator(new OriginSettingClient(client, ClientHelper.ML_ORIGIN), config,
                extractedFields);
            try (LocalModel localModel = localModelPlainActionFuture.actionGet()) {
                LOGGER.debug("Loaded inference model [{}]", localModel);
                inferTestDocs(localModel, testDocsIterator);
            }
        } catch (Exception e) {
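Using the loaded LocalModel in a try-with-resources block guarantees its close() runs even when inferTestDocs throws, so the reference taken at load time is always released. A minimal sketch of that contract, with simplified illustrative types (the real LocalModel also accounts against a circuit breaker):

import java.util.concurrent.atomic.AtomicLong;

// Sketch: a reference-counted model released through AutoCloseable, so
// try-with-resources always returns the reference it took.
class RefCountedModelSketch implements AutoCloseable {
    private final AtomicLong referenceCount = new AtomicLong(1);

    void inferTestDocs() {
        // ... run inference over the test documents ...
    }

    @Override
    public void close() {
        if (referenceCount.decrementAndGet() == 0) {
            // last reference gone: free the model's memory, adjust accounting
        }
    }
}

Typical use mirrors the diff: try (RefCountedModelSketch model = new RefCountedModelSketch()) { model.inferTestDocs(); }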
LocalModel.java

@@ -228,4 +228,22 @@ public class LocalModel implements Closeable {
    public void close() {
        release();
    }

    @Override
    public String toString() {
        return "LocalModel{" +
            "trainedModelDefinition=" + trainedModelDefinition +
            ", modelId='" + modelId + '\'' +
            ", fieldNames=" + fieldNames +
            ", defaultFieldMap=" + defaultFieldMap +
            ", statsAccumulator=" + statsAccumulator +
            ", trainedModelStatsService=" + trainedModelStatsService +
            ", persistenceQuotient=" + persistenceQuotient +
            ", currentInferenceCount=" + currentInferenceCount +
            ", inferenceConfig=" + inferenceConfig +
            ", licenseLevel=" + licenseLevel +
            ", trainedModelCircuitBreaker=" + trainedModelCircuitBreaker +
            ", referenceCount=" + referenceCount +
            '}';
    }
}
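The toString overrides added across this commit exist to feed debug lines such as LOGGER.debug("Loaded inference model [{}]", localModel). With Log4j's parameterized form, the argument's toString() only runs if the message is actually formatted, so these fairly large strings are never built unless DEBUG is enabled for the logger. A small demonstration:

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

class DeferredToStringDemo {
    private static final Logger LOGGER = LogManager.getLogger(DeferredToStringDemo.class);

    static void logModel(Object localModel) {
        // toString() on localModel is invoked only when DEBUG is enabled for
        // this logger; at INFO and above this call is effectively free.
        LOGGER.debug("Loaded inference model [{}]", localModel);
    }
}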