This commit is contained in:
parent
0f51934bcf
commit
bdf0eab78d
|
@ -353,9 +353,9 @@ public class ModelLoadingService implements ClusterStateListener {
|
||||||
logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage()));
|
logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage()));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auditor.warning(modelId, msg.get().getFormattedMessage());
|
auditor.info(modelId, msg.get().getFormattedMessage());
|
||||||
shouldNotAudit.add(modelId);
|
shouldNotAudit.add(modelId);
|
||||||
logger.warn("[{}] {}", modelId, msg.get().getFormattedMessage());
|
logger.info("[{}] {}", modelId, msg.get().getFormattedMessage());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void loadModels(Set<String> modelIds) {
|
private void loadModels(Set<String> modelIds) {
|
||||||
|
|
|
@ -36,9 +36,9 @@ import org.elasticsearch.threadpool.ThreadPool;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
||||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
|
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
|
||||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||||
|
@ -46,6 +46,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||||
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
import org.mockito.ArgumentMatcher;
|
||||||
import org.mockito.Mockito;
|
import org.mockito.Mockito;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -61,6 +62,7 @@ import static org.hamcrest.Matchers.is;
|
||||||
import static org.hamcrest.Matchers.not;
|
import static org.hamcrest.Matchers.not;
|
||||||
import static org.hamcrest.Matchers.nullValue;
|
import static org.hamcrest.Matchers.nullValue;
|
||||||
import static org.mockito.Matchers.any;
|
import static org.mockito.Matchers.any;
|
||||||
|
import static org.mockito.Matchers.argThat;
|
||||||
import static org.mockito.Matchers.eq;
|
import static org.mockito.Matchers.eq;
|
||||||
import static org.mockito.Mockito.atMost;
|
import static org.mockito.Mockito.atMost;
|
||||||
import static org.mockito.Mockito.doAnswer;
|
import static org.mockito.Mockito.doAnswer;
|
||||||
|
@ -148,8 +150,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||||
String model1 = "test-cached-limit-load-model-1";
|
String model1 = "test-cached-limit-load-model-1";
|
||||||
String model2 = "test-cached-limit-load-model-2";
|
String model2 = "test-cached-limit-load-model-2";
|
||||||
String model3 = "test-cached-limit-load-model-3";
|
String model3 = "test-cached-limit-load-model-3";
|
||||||
|
String[] modelIds = new String[]{model1, model2, model3};
|
||||||
withTrainedModel(model1, 10L);
|
withTrainedModel(model1, 10L);
|
||||||
withTrainedModel(model2, 5L);
|
withTrainedModel(model2, 6L);
|
||||||
withTrainedModel(model3, 15L);
|
withTrainedModel(model3, 15L);
|
||||||
|
|
||||||
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
|
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
|
||||||
|
@ -163,15 +166,15 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||||
|
|
||||||
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
|
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
|
||||||
|
|
||||||
// Should have been loaded from the cluster change event
|
// Should have been loaded from the cluster change event but it is unknown in what order
|
||||||
// Verify that we have at least loaded all three so that evictions occur in the following loop
|
// the loading occurred or which models are currently in the cache due to evictions.
|
||||||
|
// Verify that we have at least loaded all three
|
||||||
assertBusy(() -> {
|
assertBusy(() -> {
|
||||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
|
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
|
||||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
|
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
|
||||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
|
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
|
||||||
});
|
});
|
||||||
|
|
||||||
String[] modelIds = new String[]{model1, model2, model3};
|
|
||||||
for(int i = 0; i < 10; i++) {
|
for(int i = 0; i < 10; i++) {
|
||||||
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
|
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
|
||||||
String model = modelIds[i%2];
|
String model = modelIds[i%2];
|
||||||
|
@ -180,28 +183,55 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||||
assertThat(future.get(), is(not(nullValue())));
|
assertThat(future.get(), is(not(nullValue())));
|
||||||
}
|
}
|
||||||
|
|
||||||
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model1), eq(true), any());
|
verify(trainedModelProvider, times(2)).getTrainedModel(eq(model1), eq(true), any());
|
||||||
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any());
|
verify(trainedModelProvider, times(2)).getTrainedModel(eq(model2), eq(true), any());
|
||||||
// Only loaded requested once on the initial load from the change event
|
// Only loaded requested once on the initial load from the change event
|
||||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
|
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
|
||||||
|
|
||||||
// Load model 3, should invalidate 1
|
// model 3 has been loaded and evicted exactly once
|
||||||
|
verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
|
||||||
|
@Override
|
||||||
|
public boolean matches(final Object o) {
|
||||||
|
return ((InferenceStats)o).getModelId().equals(model3);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Load model 3, should invalidate 1 and 2
|
||||||
for(int i = 0; i < 10; i++) {
|
for(int i = 0; i < 10; i++) {
|
||||||
PlainActionFuture<Model<? extends InferenceConfig>> future3 = new PlainActionFuture<>();
|
PlainActionFuture<Model<? extends InferenceConfig>> future3 = new PlainActionFuture<>();
|
||||||
modelLoadingService.getModel(model3, future3);
|
modelLoadingService.getModel(model3, future3);
|
||||||
assertThat(future3.get(), is(not(nullValue())));
|
assertThat(future3.get(), is(not(nullValue())));
|
||||||
}
|
}
|
||||||
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model3), eq(true), any());
|
verify(trainedModelProvider, times(2)).getTrainedModel(eq(model3), eq(true), any());
|
||||||
|
|
||||||
// Load model 1, should invalidate 2
|
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
|
||||||
|
@Override
|
||||||
|
public boolean matches(final Object o) {
|
||||||
|
return ((InferenceStats)o).getModelId().equals(model1);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
|
||||||
|
@Override
|
||||||
|
public boolean matches(final Object o) {
|
||||||
|
return ((InferenceStats)o).getModelId().equals(model2);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Load model 1, should invalidate 3
|
||||||
for(int i = 0; i < 10; i++) {
|
for(int i = 0; i < 10; i++) {
|
||||||
PlainActionFuture<Model<? extends InferenceConfig>> future1 = new PlainActionFuture<>();
|
PlainActionFuture<Model<? extends InferenceConfig>> future1 = new PlainActionFuture<>();
|
||||||
modelLoadingService.getModel(model1, future1);
|
modelLoadingService.getModel(model1, future1);
|
||||||
assertThat(future1.get(), is(not(nullValue())));
|
assertThat(future1.get(), is(not(nullValue())));
|
||||||
}
|
}
|
||||||
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
|
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
|
||||||
|
verify(trainedModelStatsService, times(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
|
||||||
|
@Override
|
||||||
|
public boolean matches(final Object o) {
|
||||||
|
return ((InferenceStats)o).getModelId().equals(model3);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
// Load model 2, should invalidate 3
|
// Load model 2
|
||||||
for(int i = 0; i < 10; i++) {
|
for(int i = 0; i < 10; i++) {
|
||||||
PlainActionFuture<Model<? extends InferenceConfig>> future2 = new PlainActionFuture<>();
|
PlainActionFuture<Model<? extends InferenceConfig>> future2 = new PlainActionFuture<>();
|
||||||
modelLoadingService.getModel(model2, future2);
|
modelLoadingService.getModel(model2, future2);
|
||||||
|
@ -209,7 +239,6 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
|
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
|
||||||
|
|
||||||
|
|
||||||
// Test invalidate cache for model3
|
// Test invalidate cache for model3
|
||||||
// Now both model 1 and 2 should fit in cache without issues
|
// Now both model 1 and 2 should fit in cache without issues
|
||||||
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
|
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
|
||||||
|
@ -222,7 +251,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||||
|
|
||||||
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
|
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
|
||||||
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
|
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
|
||||||
verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any());
|
verify(trainedModelProvider, Mockito.atLeast(5)).getTrainedModel(eq(model3), eq(true), any());
|
||||||
verify(trainedModelProvider, atMost(5)).getTrainedModel(eq(model3), eq(true), any());
|
verify(trainedModelProvider, atMost(5)).getTrainedModel(eq(model3), eq(true), any());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue