This commit is contained in:
parent
0f51934bcf
commit
bdf0eab78d
|
@ -320,7 +320,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
// Remove all that are still referenced, i.e. the intersection of allReferencedModelKeys and referencedModels
|
||||
allReferencedModelKeys.removeAll(referencedModels);
|
||||
referencedModels.addAll(allReferencedModelKeys);
|
||||
|
||||
|
||||
// Populate loadingListeners key so we know that we are currently loading the model
|
||||
for (String modelId : allReferencedModelKeys) {
|
||||
loadingListeners.put(modelId, new ArrayDeque<>());
|
||||
|
@ -353,9 +353,9 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage()));
|
||||
return;
|
||||
}
|
||||
auditor.warning(modelId, msg.get().getFormattedMessage());
|
||||
auditor.info(modelId, msg.get().getFormattedMessage());
|
||||
shouldNotAudit.add(modelId);
|
||||
logger.warn("[{}] {}", modelId, msg.get().getFormattedMessage());
|
||||
logger.info("[{}] {}", modelId, msg.get().getFormattedMessage());
|
||||
}
|
||||
|
||||
private void loadModels(Set<String> modelIds) {
|
||||
|
|
|
@ -36,9 +36,9 @@ import org.elasticsearch.threadpool.ThreadPool;
|
|||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
|
@ -46,6 +46,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
|||
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.mockito.ArgumentMatcher;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -61,6 +62,7 @@ import static org.hamcrest.Matchers.is;
|
|||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.argThat;
|
||||
import static org.mockito.Matchers.eq;
|
||||
import static org.mockito.Mockito.atMost;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
|
@ -148,8 +150,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
String model1 = "test-cached-limit-load-model-1";
|
||||
String model2 = "test-cached-limit-load-model-2";
|
||||
String model3 = "test-cached-limit-load-model-3";
|
||||
String[] modelIds = new String[]{model1, model2, model3};
|
||||
withTrainedModel(model1, 10L);
|
||||
withTrainedModel(model2, 5L);
|
||||
withTrainedModel(model2, 6L);
|
||||
withTrainedModel(model3, 15L);
|
||||
|
||||
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
|
||||
|
@ -163,15 +166,15 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
|
||||
|
||||
// Should have been loaded from the cluster change event
|
||||
// Verify that we have at least loaded all three so that evictions occur in the following loop
|
||||
// Should have been loaded from the cluster change event but it is unknown in what order
|
||||
// the loading occurred or which models are currently in the cache due to evictions.
|
||||
// Verify that we have at least loaded all three
|
||||
assertBusy(() -> {
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
|
||||
});
|
||||
|
||||
String[] modelIds = new String[]{model1, model2, model3};
|
||||
for(int i = 0; i < 10; i++) {
|
||||
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
|
||||
String model = modelIds[i%2];
|
||||
|
@ -180,28 +183,55 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
assertThat(future.get(), is(not(nullValue())));
|
||||
}
|
||||
|
||||
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model1), eq(true), any());
|
||||
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any());
|
||||
verify(trainedModelProvider, times(2)).getTrainedModel(eq(model1), eq(true), any());
|
||||
verify(trainedModelProvider, times(2)).getTrainedModel(eq(model2), eq(true), any());
|
||||
// Load was requested only once, on the initial load from the change event
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
|
||||
|
||||
// Load model 3, should invalidate 1
|
||||
// model 3 has been loaded and evicted exactly once
|
||||
verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
|
||||
@Override
|
||||
public boolean matches(final Object o) {
|
||||
return ((InferenceStats)o).getModelId().equals(model3);
|
||||
}
|
||||
}));
|
||||
|
||||
// Load model 3, should invalidate 1 and 2
|
||||
for(int i = 0; i < 10; i++) {
|
||||
PlainActionFuture<Model<? extends InferenceConfig>> future3 = new PlainActionFuture<>();
|
||||
modelLoadingService.getModel(model3, future3);
|
||||
assertThat(future3.get(), is(not(nullValue())));
|
||||
}
|
||||
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model3), eq(true), any());
|
||||
verify(trainedModelProvider, times(2)).getTrainedModel(eq(model3), eq(true), any());
|
||||
|
||||
// Load model 1, should invalidate 2
|
||||
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
|
||||
@Override
|
||||
public boolean matches(final Object o) {
|
||||
return ((InferenceStats)o).getModelId().equals(model1);
|
||||
}
|
||||
}));
|
||||
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
|
||||
@Override
|
||||
public boolean matches(final Object o) {
|
||||
return ((InferenceStats)o).getModelId().equals(model2);
|
||||
}
|
||||
}));
|
||||
|
||||
// Load model 1, should invalidate 3
|
||||
for(int i = 0; i < 10; i++) {
|
||||
PlainActionFuture<Model<? extends InferenceConfig>> future1 = new PlainActionFuture<>();
|
||||
modelLoadingService.getModel(model1, future1);
|
||||
assertThat(future1.get(), is(not(nullValue())));
|
||||
}
|
||||
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
|
||||
verify(trainedModelStatsService, times(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
|
||||
@Override
|
||||
public boolean matches(final Object o) {
|
||||
return ((InferenceStats)o).getModelId().equals(model3);
|
||||
}
|
||||
}));
|
||||
|
||||
// Load model 2, should invalidate 3
|
||||
// Load model 2
|
||||
for(int i = 0; i < 10; i++) {
|
||||
PlainActionFuture<Model<? extends InferenceConfig>> future2 = new PlainActionFuture<>();
|
||||
modelLoadingService.getModel(model2, future2);
|
||||
|
@ -209,7 +239,6 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
}
|
||||
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
|
||||
|
||||
|
||||
// Test invalidate cache for model3
|
||||
// Now both model 1 and 2 should fit in cache without issues
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
|
||||
|
@ -222,7 +251,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
|
||||
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
|
||||
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
|
||||
verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any());
|
||||
verify(trainedModelProvider, Mockito.atLeast(5)).getTrainedModel(eq(model3), eq(true), any());
|
||||
verify(trainedModelProvider, atMost(5)).getTrainedModel(eq(model3), eq(true), any());
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue