[7.x] Fix non-deterministic behaviour in ModelLoadingServiceTests (#55008) (#55213)

This commit is contained in:
David Kyle 2020-04-15 11:09:12 +01:00 committed by GitHub
parent 0f51934bcf
commit bdf0eab78d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 16 deletions

View File

@ -353,9 +353,9 @@ public class ModelLoadingService implements ClusterStateListener {
logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage())); logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage()));
return; return;
} }
auditor.warning(modelId, msg.get().getFormattedMessage()); auditor.info(modelId, msg.get().getFormattedMessage());
shouldNotAudit.add(modelId); shouldNotAudit.add(modelId);
logger.warn("[{}] {}", modelId, msg.get().getFormattedMessage()); logger.info("[{}] {}", modelId, msg.get().getFormattedMessage());
} }
private void loadModels(Set<String> modelIds) { private void loadModels(Set<String> modelIds) {

View File

@ -36,9 +36,9 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
@ -46,6 +46,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.mockito.ArgumentMatcher;
import org.mockito.Mockito; import org.mockito.Mockito;
import java.io.IOException; import java.io.IOException;
@ -61,6 +62,7 @@ import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
@ -148,8 +150,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
String model1 = "test-cached-limit-load-model-1"; String model1 = "test-cached-limit-load-model-1";
String model2 = "test-cached-limit-load-model-2"; String model2 = "test-cached-limit-load-model-2";
String model3 = "test-cached-limit-load-model-3"; String model3 = "test-cached-limit-load-model-3";
String[] modelIds = new String[]{model1, model2, model3};
withTrainedModel(model1, 10L); withTrainedModel(model1, 10L);
withTrainedModel(model2, 5L); withTrainedModel(model2, 6L);
withTrainedModel(model3, 15L); withTrainedModel(model3, 15L);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
@ -163,15 +166,15 @@ public class ModelLoadingServiceTests extends ESTestCase {
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
// Should have been loaded from the cluster change event // Should have been loaded from the cluster change event but it is unknown in what order
// Verify that we have at least loaded all three so that evictions occur in the following loop // the loading occurred or which models are currently in the cache due to evictions.
// Verify that we have at least loaded all three
assertBusy(() -> { assertBusy(() -> {
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
}); });
String[] modelIds = new String[]{model1, model2, model3};
for(int i = 0; i < 10; i++) { for(int i = 0; i < 10; i++) {
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
String model = modelIds[i%2]; String model = modelIds[i%2];
@ -180,28 +183,55 @@ public class ModelLoadingServiceTests extends ESTestCase {
assertThat(future.get(), is(not(nullValue()))); assertThat(future.get(), is(not(nullValue())));
} }
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model1), eq(true), any()); verify(trainedModelProvider, times(2)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any()); verify(trainedModelProvider, times(2)).getTrainedModel(eq(model2), eq(true), any());
// Only loaded requested once on the initial load from the change event // Only loaded requested once on the initial load from the change event
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
// Load model 3, should invalidate 1 // model 3 has been loaded and evicted exactly once
verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model3);
}
}));
// Load model 3, should invalidate 1 and 2
for(int i = 0; i < 10; i++) { for(int i = 0; i < 10; i++) {
PlainActionFuture<Model<? extends InferenceConfig>> future3 = new PlainActionFuture<>(); PlainActionFuture<Model<? extends InferenceConfig>> future3 = new PlainActionFuture<>();
modelLoadingService.getModel(model3, future3); modelLoadingService.getModel(model3, future3);
assertThat(future3.get(), is(not(nullValue()))); assertThat(future3.get(), is(not(nullValue())));
} }
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model3), eq(true), any()); verify(trainedModelProvider, times(2)).getTrainedModel(eq(model3), eq(true), any());
// Load model 1, should invalidate 2 verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model1);
}
}));
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model2);
}
}));
// Load model 1, should invalidate 3
for(int i = 0; i < 10; i++) { for(int i = 0; i < 10; i++) {
PlainActionFuture<Model<? extends InferenceConfig>> future1 = new PlainActionFuture<>(); PlainActionFuture<Model<? extends InferenceConfig>> future1 = new PlainActionFuture<>();
modelLoadingService.getModel(model1, future1); modelLoadingService.getModel(model1, future1);
assertThat(future1.get(), is(not(nullValue()))); assertThat(future1.get(), is(not(nullValue())));
} }
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any()); verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelStatsService, times(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model3);
}
}));
// Load model 2, should invalidate 3 // Load model 2
for(int i = 0; i < 10; i++) { for(int i = 0; i < 10; i++) {
PlainActionFuture<Model<? extends InferenceConfig>> future2 = new PlainActionFuture<>(); PlainActionFuture<Model<? extends InferenceConfig>> future2 = new PlainActionFuture<>();
modelLoadingService.getModel(model2, future2); modelLoadingService.getModel(model2, future2);
@ -209,7 +239,6 @@ public class ModelLoadingServiceTests extends ESTestCase {
} }
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any()); verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
// Test invalidate cache for model3 // Test invalidate cache for model3
// Now both model 1 and 2 should fit in cache without issues // Now both model 1 and 2 should fit in cache without issues
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
@ -222,7 +251,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any()); verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any()); verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any()); verify(trainedModelProvider, Mockito.atLeast(5)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, atMost(5)).getTrainedModel(eq(model3), eq(true), any()); verify(trainedModelProvider, atMost(5)).getTrainedModel(eq(model3), eq(true), any());
} }