diff --git a/core/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java b/core/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java index bfa6b7d6577..96f62b44229 100644 --- a/core/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java +++ b/core/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java @@ -511,7 +511,9 @@ public class BalancedShardsAllocator extends AbstractComponent implements Shards continue; } RoutingNode target = routingNodes.node(currentNode.getNodeId()); - Decision decision = allocation.deciders().canAllocate(shard, target, allocation); + Decision allocationDecision = allocation.deciders().canAllocate(shard, target, allocation); + Decision rebalanceDecision = allocation.deciders().canRebalance(shard, allocation); + Decision decision = new Decision.Multi().add(allocationDecision).add(rebalanceDecision); if (decision.type() == Type.YES) { // TODO maybe we can respect throttling here too? sourceNode.removeShard(shard); ShardRouting targetRelocatingShard = routingNodes.relocate(shard, target.nodeId(), allocation.clusterInfo().getShardSize(shard, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE)); diff --git a/core/src/test/java/org/elasticsearch/cluster/routing/allocation/FilterRoutingTests.java b/core/src/test/java/org/elasticsearch/cluster/routing/allocation/FilterRoutingTests.java index f7f694d5a40..43aee722c59 100644 --- a/core/src/test/java/org/elasticsearch/cluster/routing/allocation/FilterRoutingTests.java +++ b/core/src/test/java/org/elasticsearch/cluster/routing/allocation/FilterRoutingTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.Version; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.routing.RoutingTable; import org.elasticsearch.cluster.routing.ShardRouting; @@ -36,6 +37,7 @@ import java.util.List; import static java.util.Collections.singletonMap; import static org.elasticsearch.cluster.routing.ShardRoutingState.INITIALIZING; +import static org.elasticsearch.cluster.routing.ShardRoutingState.STARTED; import static org.elasticsearch.common.settings.Settings.settingsBuilder; import static org.hamcrest.Matchers.equalTo; @@ -162,4 +164,69 @@ public class FilterRoutingTests extends ESAllocationTestCase { assertThat(startedShard.currentNodeId(), Matchers.anyOf(equalTo("node1"), equalTo("node4"))); } } + + public void testRebalanceAfterShardsCannotRemainOnNode() { + AllocationService strategy = createAllocationService(settingsBuilder().build()); + + logger.info("Building initial routing table"); + MetaData metaData = MetaData.builder() + .put(IndexMetaData.builder("test1").settings(settings(Version.CURRENT)).numberOfShards(2).numberOfReplicas(0)) + .put(IndexMetaData.builder("test2").settings(settings(Version.CURRENT)).numberOfShards(2).numberOfReplicas(0)) + .build(); + + RoutingTable routingTable = RoutingTable.builder() + .addAsNew(metaData.index("test1")) + .addAsNew(metaData.index("test2")) + .build(); + + ClusterState clusterState = ClusterState.builder(org.elasticsearch.cluster.ClusterName.DEFAULT).metaData(metaData).routingTable(routingTable).build(); + + logger.info("--> adding two nodes and performing rerouting"); + DiscoveryNode node1 = newNode("node1", singletonMap("tag1", "value1")); + DiscoveryNode node2 = newNode("node2", singletonMap("tag1", "value2")); + clusterState = ClusterState.builder(clusterState).nodes(DiscoveryNodes.builder().put(node1).put(node2)).build(); + routingTable = strategy.reroute(clusterState).routingTable(); + clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build(); + assertThat(clusterState.getRoutingNodes().node(node1.getId()).numberOfShardsWithState(INITIALIZING), equalTo(2)); + assertThat(clusterState.getRoutingNodes().node(node2.getId()).numberOfShardsWithState(INITIALIZING), equalTo(2)); + + logger.info("--> start the shards (only primaries)"); + routingTable = strategy.applyStartedShards(clusterState, clusterState.getRoutingNodes().shardsWithState(INITIALIZING)).routingTable(); + clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build(); + + logger.info("--> make sure all shards are started"); + assertThat(clusterState.getRoutingNodes().shardsWithState(STARTED).size(), equalTo(4)); + + logger.info("--> disable allocation for node1 and reroute"); + strategy = createAllocationService(settingsBuilder() + .put("cluster.routing.allocation.cluster_concurrent_rebalance", "1") + .put("cluster.routing.allocation.exclude.tag1", "value1") + .build()); + + logger.info("--> move shards from node1 to node2"); + routingTable = strategy.reroute(clusterState).routingTable(); + clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build(); + logger.info("--> check that concurrent rebalance only allows 1 shard to move"); + assertThat(clusterState.getRoutingNodes().node(node1.getId()).numberOfShardsWithState(STARTED), equalTo(1)); + assertThat(clusterState.getRoutingNodes().node(node2.getId()).numberOfShardsWithState(INITIALIZING), equalTo(1)); + assertThat(clusterState.getRoutingNodes().node(node2.getId()).numberOfShardsWithState(STARTED), equalTo(2)); + + logger.info("--> start the shards (only primaries)"); + routingTable = strategy.applyStartedShards(clusterState, clusterState.getRoutingNodes().shardsWithState(INITIALIZING)).routingTable(); + clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build(); + + logger.info("--> move second shard from node1 to node2"); + routingTable = strategy.reroute(clusterState).routingTable(); + clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build(); + assertThat(clusterState.getRoutingNodes().node(node2.getId()).numberOfShardsWithState(INITIALIZING), equalTo(1)); + assertThat(clusterState.getRoutingNodes().node(node2.getId()).numberOfShardsWithState(STARTED), equalTo(3)); + + logger.info("--> start the shards (only primaries)"); + routingTable = strategy.applyStartedShards(clusterState, clusterState.getRoutingNodes().shardsWithState(INITIALIZING)).routingTable(); + clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build(); + + routingTable = strategy.reroute(clusterState).routingTable(); + clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build(); + assertThat(clusterState.getRoutingNodes().node(node2.getId()).numberOfShardsWithState(STARTED), equalTo(4)); + } }