[7.x] Fix non-deterministic behaviour in ModelLoadingServiceTests (#55008) (#55213)

This commit is contained in:
David Kyle 2020-04-15 11:09:12 +01:00 committed by GitHub
parent 0f51934bcf
commit bdf0eab78d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 16 deletions

View File

@@ -320,7 +320,7 @@ public class ModelLoadingService implements ClusterStateListener {
// Remove all that are still referenced, i.e. the intersection of allReferencedModelKeys and referencedModels
allReferencedModelKeys.removeAll(referencedModels);
referencedModels.addAll(allReferencedModelKeys);
// Populate loadingListeners key so we know that we are currently loading the model
for (String modelId : allReferencedModelKeys) {
loadingListeners.put(modelId, new ArrayDeque<>());
@@ -353,9 +353,9 @@ public class ModelLoadingService implements ClusterStateListener {
logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage()));
return;
}
auditor.warning(modelId, msg.get().getFormattedMessage());
auditor.info(modelId, msg.get().getFormattedMessage());
shouldNotAudit.add(modelId);
logger.warn("[{}] {}", modelId, msg.get().getFormattedMessage());
logger.info("[{}] {}", modelId, msg.get().getFormattedMessage());
}
private void loadModels(Set<String> modelIds) {

View File

@@ -36,9 +36,9 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
@@ -46,6 +46,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentMatcher;
import org.mockito.Mockito;
import java.io.IOException;
@@ -61,6 +62,7 @@ import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.doAnswer;
@@ -148,8 +150,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
String model1 = "test-cached-limit-load-model-1";
String model2 = "test-cached-limit-load-model-2";
String model3 = "test-cached-limit-load-model-3";
String[] modelIds = new String[]{model1, model2, model3};
withTrainedModel(model1, 10L);
withTrainedModel(model2, 5L);
withTrainedModel(model2, 6L);
withTrainedModel(model3, 15L);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
@@ -163,15 +166,15 @@ public class ModelLoadingServiceTests extends ESTestCase {
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
// Should have been loaded from the cluster change event
// Verify that we have at least loaded all three so that evictions occur in the following loop
// Should have been loaded from the cluster change event but it is unknown in what order
// the loading occurred or which models are currently in the cache due to evictions.
// Verify that we have at least loaded all three
assertBusy(() -> {
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
});
String[] modelIds = new String[]{model1, model2, model3};
for(int i = 0; i < 10; i++) {
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
String model = modelIds[i%2];
@@ -180,28 +183,55 @@ public class ModelLoadingServiceTests extends ESTestCase {
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, times(2)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(2)).getTrainedModel(eq(model2), eq(true), any());
// Only loaded requested once on the initial load from the change event
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
// Load model 3, should invalidate 1
// model 3 has been loaded and evicted exactly once
verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model3);
}
}));
// Load model 3, should invalidate 1 and 2
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model<? extends InferenceConfig>> future3 = new PlainActionFuture<>();
modelLoadingService.getModel(model3, future3);
assertThat(future3.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, times(2)).getTrainedModel(eq(model3), eq(true), any());
// Load model 1, should invalidate 2
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model1);
}
}));
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model2);
}
}));
// Load model 1, should invalidate 3
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model<? extends InferenceConfig>> future1 = new PlainActionFuture<>();
modelLoadingService.getModel(model1, future1);
assertThat(future1.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelStatsService, times(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model3);
}
}));
// Load model 2, should invalidate 3
// Load model 2
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model<? extends InferenceConfig>> future2 = new PlainActionFuture<>();
modelLoadingService.getModel(model2, future2);
@@ -209,7 +239,6 @@ public class ModelLoadingServiceTests extends ESTestCase {
}
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
// Test invalidate cache for model3
// Now both model 1 and 2 should fit in cache without issues
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
@@ -222,7 +251,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, Mockito.atLeast(5)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, atMost(5)).getTrainedModel(eq(model3), eq(true), any());
}