diff --git a/solr/core/src/java/org/apache/solr/cloud/Assign.java b/solr/core/src/java/org/apache/solr/cloud/Assign.java index 265e4534f1a..d790e7a09bb 100644 --- a/solr/core/src/java/org/apache/solr/cloud/Assign.java +++ b/solr/core/src/java/org/apache/solr/cloud/Assign.java @@ -16,6 +16,7 @@ */ package org.apache.solr.cloud; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -28,6 +29,11 @@ import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.apache.solr.client.solrj.impl.CloudSolrClient; +import org.apache.solr.client.solrj.impl.SolrClientDataProvider; +import org.apache.solr.client.solrj.impl.ZkClientClusterStateProvider; +import org.apache.solr.cloud.autoscaling.Policy; +import org.apache.solr.cloud.autoscaling.PolicyHelper; import org.apache.solr.cloud.rule.ReplicaAssigner; import org.apache.solr.cloud.rule.Rule; import org.apache.solr.common.SolrException; @@ -35,11 +41,19 @@ import org.apache.solr.common.cloud.ClusterState; import org.apache.solr.common.cloud.DocCollection; import org.apache.solr.common.cloud.Replica; import org.apache.solr.common.cloud.Slice; +import org.apache.solr.common.cloud.ZkNodeProps; +import org.apache.solr.common.cloud.ZkStateReader; import org.apache.solr.common.util.StrUtils; import org.apache.solr.core.CoreContainer; +import org.apache.zookeeper.KeeperException; +import static java.util.Collections.singletonMap; +import static org.apache.solr.cloud.autoscaling.Policy.POLICY; +import static org.apache.solr.common.cloud.ZkStateReader.COLLECTION_PROP; import static org.apache.solr.common.cloud.ZkStateReader.CORE_NAME_PROP; import static org.apache.solr.common.cloud.ZkStateReader.MAX_SHARDS_PER_NODE; +import static org.apache.solr.common.cloud.ZkStateReader.SOLR_AUTOSCALING_CONF_PATH; +import static org.apache.solr.common.params.CommonParams.NAME; public class Assign { @@ -150,7 +164,7 @@ public class Assign { // could be created on live nodes given maxShardsPerNode, Replication factor (if from createShard) etc. public static List getNodesForNewReplicas(ClusterState clusterState, String collectionName, String shard, int numberOfNodes, - Object createNodeSet, CoreContainer cc) { + Object createNodeSet, CoreContainer cc) throws KeeperException, InterruptedException { DocCollection coll = clusterState.getCollection(collectionName); Integer maxShardsPerNode = coll.getInt(MAX_SHARDS_PER_NODE, 1); List createNodeList = null; @@ -179,8 +193,23 @@ public class Assign { } List l = (List) coll.get(DocCollection.RULE); + Map positions = null; if (l != null) { - return getNodesViaRules(clusterState, shard, numberOfNodes, cc, coll, createNodeList, l); + positions = getNodesViaRules(clusterState, shard, numberOfNodes, cc, coll, createNodeList, l); + } + String policyName = coll.getStr(POLICY); + Map autoSalingJson = cc.getZkController().getZkStateReader().getZkClient().getJson(SOLR_AUTOSCALING_CONF_PATH, true); + if (policyName != null || autoSalingJson.get(Policy.CLUSTER_POLICY) != null) { + positions= Assign.getPositionsUsingPolicy(collectionName, Collections.singletonList(shard), numberOfNodes, + policyName, cc.getZkController().getZkStateReader()); + } + + if(positions != null){ + List repCounts = new ArrayList<>(); + for (String s : positions.values()) { + repCounts.add(new ReplicaCount(s)); + } + return repCounts; } ArrayList sortedNodeList = new ArrayList<>(nodeNameVsShardCount.values()); @@ -188,9 +217,30 @@ public class Assign { return sortedNodeList; } + public static Map getPositionsUsingPolicy(String collName, List shardNames, int numReplicas, + String policyName, ZkStateReader zkStateReader) throws KeeperException, InterruptedException { + try (CloudSolrClient csc = new CloudSolrClient.Builder() + .withClusterStateProvider(new ZkClientClusterStateProvider(zkStateReader)) + .build()) { + SolrClientDataProvider clientDataProvider = new SolrClientDataProvider(csc); + Map> locations = PolicyHelper.getReplicaLocations(collName, + zkStateReader.getZkClient().getJson(SOLR_AUTOSCALING_CONF_PATH, true), + clientDataProvider, singletonMap(collName, policyName), shardNames, numReplicas); + Map result = new HashMap<>(); + for (Map.Entry> e : locations.entrySet()) { + List value = e.getValue(); + for (int i = 0; i < value.size(); i++) { + result.put(new ReplicaAssigner.Position(e.getKey(), i, Replica.Type.NRT), value.get(i)); + } + } + return result; + } catch (IOException e) { + throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Error closing CloudSolrClient",e); + } + } - private static List getNodesViaRules(ClusterState clusterState, String shard, int numberOfNodes, - CoreContainer cc, DocCollection coll, List createNodeList, List l) { + private static Map getNodesViaRules(ClusterState clusterState, String shard, int numberOfNodes, + CoreContainer cc, DocCollection coll, List createNodeList, List l) { ArrayList rules = new ArrayList<>(); for (Object o : l) rules.add(new Rule((Map) o)); Map> shardVsNodes = new LinkedHashMap<>(); @@ -214,11 +264,7 @@ public class Assign { shardVsNodes, nodesList, cc, clusterState).getNodeMappings(); - List repCounts = new ArrayList<>(); - for (String s : positions.values()) { - repCounts.add(new ReplicaCount(s)); - } - return repCounts; + return positions;// getReplicaCounts(positions); } private static HashMap getNodeNameVsShardCount(String collectionName, diff --git a/solr/core/src/java/org/apache/solr/cloud/OverseerCollectionMessageHandler.java b/solr/core/src/java/org/apache/solr/cloud/OverseerCollectionMessageHandler.java index 0d8e3abf0c4..2ff6285c399 100644 --- a/solr/core/src/java/org/apache/solr/cloud/OverseerCollectionMessageHandler.java +++ b/solr/core/src/java/org/apache/solr/cloud/OverseerCollectionMessageHandler.java @@ -87,6 +87,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import static java.util.Collections.singletonMap; +import static org.apache.solr.cloud.autoscaling.Policy.POLICY; import static org.apache.solr.common.cloud.DocCollection.SNITCH; import static org.apache.solr.common.cloud.ZkStateReader.BASE_URL_PROP; import static org.apache.solr.common.cloud.ZkStateReader.COLLECTION_PROP; @@ -144,6 +145,7 @@ public class OverseerCollectionMessageHandler implements OverseerMessageHandler ZkStateReader.MAX_SHARDS_PER_NODE, "1", ZkStateReader.AUTO_ADD_REPLICAS, "false", DocCollection.RULE, null, + POLICY, null, SNITCH, null)); private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); @@ -714,9 +716,9 @@ public class OverseerCollectionMessageHandler implements OverseerMessageHandler List shardNames, int numNrtReplicas, int numTlogReplicas, - int numPullReplicas) throws IOException, KeeperException, InterruptedException { + int numPullReplicas) throws KeeperException, InterruptedException { List rulesMap = (List) message.get("rule"); - String policyName = message.getStr("policy"); + String policyName = message.getStr(POLICY); Map autoSalingJson = zkStateReader.getZkClient().getJson(SOLR_AUTOSCALING_CONF_PATH, true); autoSalingJson = autoSalingJson == null ? Collections.EMPTY_MAP : autoSalingJson; @@ -746,23 +748,8 @@ public class OverseerCollectionMessageHandler implements OverseerMessageHandler } if (policyName != null || autoSalingJson.get(Policy.CLUSTER_POLICY) != null) { - String collName = message.getStr(COLLECTION_PROP, message.getStr(NAME)); - try (CloudSolrClient csc = new CloudSolrClient.Builder() - .withClusterStateProvider(new ZkClientClusterStateProvider(zkStateReader)) - .build()) { - SolrClientDataProvider clientDataProvider = new SolrClientDataProvider(csc); - Map> locations = PolicyHelper.getReplicaLocations(collName, - zkStateReader.getZkClient().getJson(SOLR_AUTOSCALING_CONF_PATH, true), - clientDataProvider, singletonMap(collName, policyName), shardNames, numNrtReplicas); - Map result = new HashMap<>(); - for (Map.Entry> e : locations.entrySet()) { - List value = e.getValue(); - for (int i = 0; i < value.size(); i++) { - result.put(new Position(e.getKey(), i, Replica.Type.NRT), value.get(i)); - } - } - return result; - } + return Assign.getPositionsUsingPolicy(message.getStr(COLLECTION_PROP, message.getStr(NAME)), + shardNames, numNrtReplicas, policyName, zkStateReader); } else { List rules = new ArrayList<>(); diff --git a/solr/core/src/test/org/apache/solr/cloud/autoscaling/AutoScalingHandlerTest.java b/solr/core/src/test/org/apache/solr/cloud/autoscaling/AutoScalingHandlerTest.java index a1913845cde..74fac04c482 100644 --- a/solr/core/src/test/org/apache/solr/cloud/autoscaling/AutoScalingHandlerTest.java +++ b/solr/core/src/test/org/apache/solr/cloud/autoscaling/AutoScalingHandlerTest.java @@ -22,7 +22,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.function.BiConsumer; +import java.util.function.Consumer; import org.apache.solr.client.solrj.SolrClient; import org.apache.solr.client.solrj.SolrRequest; @@ -44,7 +44,6 @@ import org.apache.solr.common.util.ContentStream; import org.apache.solr.common.util.ContentStreamBase; import org.apache.solr.common.util.NamedList; import org.apache.solr.common.util.Utils; -import org.apache.zookeeper.KeeperException; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; @@ -303,7 +302,7 @@ public class AutoScalingHandlerTest extends SolrCloudTestCase { } } - public void testCreateCollectionPolicy() throws Exception { + public void testCreateCollectionAddShardUsingPolicy() throws Exception { JettySolrRunner jetty = cluster.getRandomJetty(random()); int port = jetty.getLocalPort(); @@ -313,12 +312,18 @@ public class AutoScalingHandlerTest extends SolrCloudTestCase { Map json = cluster.getZkClient().getJson(ZkStateReader.SOLR_AUTOSCALING_CONF_PATH, true); assertEquals("full json:"+ Utils.toJSONString(json) , "#EACH", Utils.getObjectByPath(json, true, "/policies/c1[0]/shard")); - CollectionAdminRequest.createCollection("policiesTest",2, 1) + CollectionAdminRequest.createCollectionWithImplicitRouter("policiesTest", null, "s1,s2", 1) .setPolicy("c1") .process(cluster.getSolrClient()); DocCollection coll = getCollectionState("policiesTest"); + assertEquals("c1", coll.getPolicyName()); + assertEquals(2,coll.getReplicas().size()); coll.forEachReplica((s, replica) -> assertEquals(jetty.getNodeName(), replica.getNodeName())); + CollectionAdminRequest.createShard("policiesTest", "s3").process(cluster.getSolrClient()); + coll = getCollectionState("policiesTest"); + assertEquals(1, coll.getSlice("s3").getReplicas().size()); + coll.getSlice("s3").forEach(replica -> assertEquals(jetty.getNodeName(), replica.getNodeName())); } static SolrRequest createAutoScalingRequest(SolrRequest.METHOD m, String message) { diff --git a/solr/solrj/src/java/org/apache/solr/common/cloud/DocCollection.java b/solr/solrj/src/java/org/apache/solr/common/cloud/DocCollection.java index 6f663c5885e..5dc4ebb83e4 100644 --- a/solr/solrj/src/java/org/apache/solr/common/cloud/DocCollection.java +++ b/solr/solrj/src/java/org/apache/solr/common/cloud/DocCollection.java @@ -28,6 +28,7 @@ import java.util.Objects; import java.util.Set; import java.util.function.BiConsumer; +import org.apache.solr.cloud.autoscaling.Policy; import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrException.ErrorCode; import org.noggit.JSONUtil; @@ -67,6 +68,7 @@ public class DocCollection extends ZkNodeProps implements Iterable { private final Integer numPullReplicas; private final Integer maxShardsPerNode; private final Boolean autoAddReplicas; + private final String policy; public DocCollection(String name, Map slices, Map props, DocRouter router) { this(name, slices, props, router, Integer.MAX_VALUE, ZkStateReader.CLUSTER_STATE); @@ -93,6 +95,7 @@ public class DocCollection extends ZkNodeProps implements Iterable { this.numPullReplicas = (Integer) verifyProp(props, PULL_REPLICAS); this.maxShardsPerNode = (Integer) verifyProp(props, MAX_SHARDS_PER_NODE); Boolean autoAddReplicas = (Boolean) verifyProp(props, AUTO_ADD_REPLICAS); + this.policy = (String) props.get(Policy.POLICY); this.autoAddReplicas = autoAddReplicas == null ? Boolean.FALSE : autoAddReplicas; verifyProp(props, RULE); @@ -368,4 +371,10 @@ public class DocCollection extends ZkNodeProps implements Iterable { return numPullReplicas; } + /** + * @return the policy associated with this collection if any + */ + public String getPolicyName() { + return policy; + } }