[7.x] Respect ML upgrade mode in TrainedModelStatsService (#61143) (#61187)

When ML upgrade mode is enabled, the trained model stats service should not write to the stats index.
This commit is contained in:
David Kyle 2020-08-17 11:09:25 +01:00 committed by GitHub
parent f3e0c60896
commit ba89af544f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 133 additions and 2 deletions

View File

@ -15,6 +15,7 @@ import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.routing.IndexRoutingTable;
@ -29,6 +30,7 @@ import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType; import org.elasticsearch.script.ScriptType;
import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
@ -102,7 +104,12 @@ public class TrainedModelStatsService {
stop(); stop();
} }
}); });
clusterService.addListener((event) -> this.clusterState = event.state()); clusterService.addListener(this::setClusterState);
}
/**
 * Stores the cluster state carried by {@code event} in the {@code clusterState} field.
 * The constructor registers this method as the cluster service listener
 * ({@code this::setClusterState}); it is a named, package-private method so that
 * tests can hand the service a cluster state directly.
 *
 * @param event the cluster change event whose {@code state()} becomes the current state
 */
// visible for testing
void setClusterState(ClusterChangedEvent event) {
clusterState = event.state();
} }
/** /**
@ -146,6 +153,13 @@ public class TrainedModelStatsService {
if (clusterState == null || statsQueue.isEmpty() || stopped) { if (clusterState == null || statsQueue.isEmpty() || stopped) {
return; return;
} }
boolean isInUpgradeMode = MlMetadata.getMlMetadata(clusterState).isUpgradeMode();
if (isInUpgradeMode) {
logger.debug("Model stats not persisted as ml upgrade mode is enabled");
return;
}
if (verifyIndicesExistAndPrimaryShardsAreActive(clusterState, indexNameExpressionResolver) == false) { if (verifyIndicesExistAndPrimaryShardsAreActive(clusterState, indexNameExpressionResolver) == false) {
try { try {
logger.debug("About to create the stats index as it does not exist yet"); logger.debug("About to create the stats index as it does not exist yet");
@ -251,5 +265,4 @@ public class TrainedModelStatsService {
} }
return null; return null;
} }
} }

View File

@ -7,6 +7,9 @@
package org.elasticsearch.xpack.ml.inference; package org.elasticsearch.xpack.ml.inference;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.AliasMetadata; import org.elasticsearch.cluster.metadata.AliasMetadata;
@ -19,13 +22,25 @@ import org.elasticsearch.cluster.routing.RecoverySource;
import org.elasticsearch.cluster.routing.RoutingTable; import org.elasticsearch.cluster.routing.RoutingTable;
import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.routing.UnassignedInfo; import org.elasticsearch.cluster.routing.UnassignedInfo;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.Index; import org.elasticsearch.index.Index;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import java.time.Instant;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class TrainedModelStatsServiceTests extends ESTestCase { public class TrainedModelStatsServiceTests extends ESTestCase {
@ -119,6 +134,109 @@ public class TrainedModelStatsServiceTests extends ESTestCase {
} }
} }
/**
 * Checks that {@code updateStats()} honours ML upgrade mode.
 * <ol>
 *   <li>No ML metadata in the cluster state: queued stats are persisted (1 bulk call).</li>
 *   <li>Upgrade mode enabled: newly queued stats are not persisted (still 1 bulk call).</li>
 *   <li>Upgrade mode disabled again: the stats left in the queue by step 2 are persisted
 *       without queueing anything new (2 bulk calls in total).</li>
 * </ol>
 */
public void testUpdateStatsUpgradeMode() {
    String aliasName = MlStatsIndex.writeAlias();
    String concreteIndex = ".ml-stats-000001";
    IndexNameExpressionResolver resolver = new IndexNameExpressionResolver();

    // create a valid index routing so persistence will occur
    RoutingTable.Builder routingTableBuilder = RoutingTable.builder();
    addToRoutingTable(concreteIndex, routingTableBuilder);
    RoutingTable routingTable = routingTableBuilder.build();

    OriginSettingClient originSettingClient =
        MockOriginSettingClient.mockOriginSettingClient(mock(Client.class), "modelstatsservicetests");
    ClusterService clusterService = mock(ClusterService.class);
    ThreadPool threadPool = mock(ThreadPool.class);
    ResultsPersisterService persisterService = mock(ResultsPersisterService.class);

    TrainedModelStatsService service = new TrainedModelStatsService(persisterService,
        originSettingClient, resolver, clusterService, threadPool);

    InferenceStats.Accumulator accumulator = new InferenceStats.Accumulator("testUpdateStatsUpgradeMode", "test-node", 1L);

    // Step 1: cluster state without any ML metadata custom
    service.setClusterState(
        upgradeModeTestEvent("upgrade-mode-test-initial-state", concreteIndex, aliasName, routingTable, null));
    // queue some stats to be persisted
    service.queueStats(accumulator.currentStats(Instant.now()), false);
    service.updateStats();
    verify(persisterService, times(1)).bulkIndexWithRetry(any(), any(), any(), any());

    // Step 2: test with upgrade mode turned on
    service.setClusterState(
        upgradeModeTestEvent("upgrade-mode-test-upgrade-enabled", concreteIndex, aliasName, routingTable, true));
    // queue some stats to be persisted
    service.queueStats(accumulator.currentStats(Instant.now()), false);
    service.updateStats();
    // the call count is unchanged: nothing was written while in upgrade mode
    verify(persisterService, times(1)).bulkIndexWithRetry(any(), any(), any(), any());

    // Step 3: this time turn off upgrade mode
    service.setClusterState(
        upgradeModeTestEvent("upgrade-mode-test-upgrade-disabled", concreteIndex, aliasName, routingTable, false));
    // nothing new is queued here; the stats queued in step 2 are expected to be written now
    service.updateStats();
    verify(persisterService, times(2)).bulkIndexWithRetry(any(), any(), any(), any());
}

/**
 * Builds a {@link ClusterChangedEvent} whose current and previous state are the same cluster
 * state, containing a single one-shard, zero-replica index with a hidden alias.
 *
 * @param clusterName   name given to the generated cluster state
 * @param concreteIndex name of the concrete index to register in the metadata
 * @param aliasName     hidden alias to attach to {@code concreteIndex}
 * @param routingTable  routing table to install in the cluster state
 * @param upgradeMode   {@code null} to omit the {@link MlMetadata} custom entirely, otherwise the
 *                      upgrade-mode flag to store in it
 * @return the event wrapping the generated cluster state
 */
private static ClusterChangedEvent upgradeModeTestEvent(String clusterName,
                                                        String concreteIndex,
                                                        String aliasName,
                                                        RoutingTable routingTable,
                                                        Boolean upgradeMode) {
    IndexMetadata.Builder indexMetadata = IndexMetadata.builder(concreteIndex)
        .putAlias(AliasMetadata.builder(aliasName).isHidden(true).build())
        .settings(Settings.builder()
            .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
        );
    Metadata.Builder metadata = Metadata.builder().put(indexMetadata);
    if (upgradeMode != null) {
        metadata.putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(upgradeMode).build());
    }
    ClusterState clusterState = ClusterState.builder(new ClusterName(clusterName))
        .routingTable(routingTable)
        .metadata(metadata)
        .build();
    return new ClusterChangedEvent("created-from-test", clusterState, clusterState);
}
private static void addToRoutingTable(String concreteIndex, RoutingTable.Builder routingTable) { private static void addToRoutingTable(String concreteIndex, RoutingTable.Builder routingTable) {
Index index = new Index(concreteIndex, "_uuid"); Index index = new Index(concreteIndex, "_uuid");
ShardId shardId = new ShardId(index, 0); ShardId shardId = new ShardId(index, 0);