diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/impl/CloudSolrServer.java b/solr/solrj/src/java/org/apache/solr/client/solrj/impl/CloudSolrServer.java index 62c3e37412e..1f42d2c1e69 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/impl/CloudSolrServer.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/impl/CloudSolrServer.java @@ -19,7 +19,6 @@ package org.apache.solr.client.solrj.impl; import java.io.IOException; import java.net.MalformedURLException; -import java.net.URLDecoder; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -61,6 +60,7 @@ import org.apache.solr.common.cloud.ZkNodeProps; import org.apache.solr.common.cloud.ZkStateReader; import org.apache.solr.common.cloud.ZooKeeperException; import org.apache.solr.common.params.ModifiableSolrParams; +import org.apache.solr.common.params.ShardParams; import org.apache.solr.common.params.SolrParams; import org.apache.solr.common.params.UpdateParams; import org.apache.solr.common.util.NamedList; @@ -542,6 +542,11 @@ public class CloudSolrServer extends SolrServer { "Could not find collection: " + collection); } + String shardKeys = reqParams.get(ShardParams._ROUTE_); + if(shardKeys == null) { + shardKeys = reqParams.get(ShardParams.SHARD_KEYS); // deprecated + } + // TODO: not a big deal because of the caching, but we could avoid looking // at every shard // when getting leaders if we tweaked some things @@ -551,16 +556,12 @@ public class CloudSolrServer extends SolrServer { // add it to the Map of slices. Map slices = new HashMap<>(); for (String collectionName : collectionsList) { - Collection colSlices = clusterState - .getActiveSlices(collectionName); - if (colSlices == null) { - throw new SolrServerException("Could not find collection:" - + collectionName); - } - ClientUtils.addSlices(slices, collectionName, colSlices, true); + DocCollection col = clusterState.getCollection(collectionName); + Collection routeSlices = col.getRouter().getSearchSlices(shardKeys, reqParams , col); + ClientUtils.addSlices(slices, collectionName, routeSlices, true); } Set liveNodes = clusterState.getLiveNodes(); - + List leaderUrlList = null; List urlList = null; List replicasList = null; diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/impl/CloudSolrServerTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/impl/CloudSolrServerTest.java index eb29cb10213..075ab51ab80 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/impl/CloudSolrServerTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/impl/CloudSolrServerTest.java @@ -18,13 +18,20 @@ package org.apache.solr.client.solrj.impl; */ import java.io.File; +import java.io.IOException; import java.net.MalformedURLException; +import java.util.Collection; import java.util.Iterator; +import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeoutException; -import org.apache.lucene.util.LuceneTestCase.BadApple; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; import org.apache.lucene.util.LuceneTestCase.Slow; +import org.apache.solr.client.solrj.SolrServerException; import org.apache.solr.client.solrj.request.AbstractUpdateRequest; import org.apache.solr.client.solrj.request.QueryRequest; import org.apache.solr.client.solrj.request.UpdateRequest; @@ -33,7 +40,15 @@ import org.apache.solr.cloud.AbstractFullDistribZkTestBase; import org.apache.solr.cloud.AbstractZkTestCase; import org.apache.solr.common.SolrDocumentList; import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.common.cloud.ClusterState; +import org.apache.solr.common.cloud.DocCollection; +import org.apache.solr.common.cloud.DocRouter; +import org.apache.solr.common.cloud.Replica; +import org.apache.solr.common.cloud.Slice; +import org.apache.solr.common.cloud.ZkStateReader; +import org.apache.solr.common.params.CommonParams; import org.apache.solr.common.params.ModifiableSolrParams; +import org.apache.solr.common.params.ShardParams; import org.apache.solr.common.util.NamedList; import org.junit.After; import org.junit.AfterClass; @@ -189,10 +204,105 @@ public class CloudSolrServerTest extends AbstractFullDistribZkTestBase { } finally { threadedClient.shutdown(); } - + + // Test that queries with _route_ params are routed by the client + + // Track request counts on each node before query calls + ClusterState clusterState = cloudClient.getZkStateReader().getClusterState(); + DocCollection col = clusterState.getCollection(DEFAULT_COLLECTION); + Map requestCountsMap = Maps.newHashMap(); + for (Slice slice : col.getSlices()) { + for (Replica replica : slice.getReplicas()) { + String baseURL = (String) replica.get(ZkStateReader.BASE_URL_PROP); + requestCountsMap.put(baseURL, getNumRequests(new HttpSolrServer(baseURL))); + } + } + + // Collect the base URLs of the replicas of shard that's expected to be hit + DocRouter router = col.getRouter(); + Collection expectedSlices = router.getSearchSlicesSingle("0", null, col); + Set expectedBaseURLs = Sets.newHashSet(); + for (Slice expectedSlice : expectedSlices) { + for (Replica replica : expectedSlice.getReplicas()) { + String baseURL = (String) replica.get(ZkStateReader.BASE_URL_PROP); + expectedBaseURLs.add(baseURL); + } + } + + assertTrue("expected urls is not fewer than all urls! expected=" + expectedBaseURLs + + "; all=" + requestCountsMap.keySet(), + expectedBaseURLs.size() < requestCountsMap.size()); + + // Calculate a number of shard keys that route to the same shard. + List sameShardRoutes = Lists.newArrayList(); + sameShardRoutes.add("0"); + for (int i = 1; i < 1000; i++) { + String shardKey = Integer.toString(i); + Collection slices = router.getSearchSlicesSingle(shardKey, null, col); + if (expectedSlices.equals(slices)) { + sameShardRoutes.add(shardKey); + } + } + + assertTrue(sameShardRoutes.size() > 1); + + // Do 1000 queries with _route_ parameter to the same shard + for (int i = 0; i < 1000; i++) { + ModifiableSolrParams solrParams = new ModifiableSolrParams(); + solrParams.set(CommonParams.Q, "*:*"); + solrParams.set(ShardParams._ROUTE_, sameShardRoutes.get(random().nextInt(sameShardRoutes.size()))); + cloudClient.query(solrParams); + } + + // Request counts increase from expected nodes should aggregate to 1000, while there should be + // no increase in unexpected nodes. + int increaseFromExpectedUrls = 0; + int increaseFromUnexpectedUrls = 0; + Map numRequestsToUnexpectedUrls = Maps.newHashMap(); + for (Slice slice : col.getSlices()) { + for (Replica replica : slice.getReplicas()) { + String baseURL = (String) replica.get(ZkStateReader.BASE_URL_PROP); + + Long prevNumRequests = requestCountsMap.get(baseURL); + Long curNumRequests = getNumRequests(new HttpSolrServer(baseURL)); + + long delta = curNumRequests - prevNumRequests; + if (expectedBaseURLs.contains(baseURL)) { + increaseFromExpectedUrls += delta; + } else { + increaseFromUnexpectedUrls += delta; + numRequestsToUnexpectedUrls.put(baseURL, delta); + } + } + } + + assertEquals("Unexpected number of requests to expected URLs", 1000, increaseFromExpectedUrls); + assertEquals("Unexpected number of requests to unexpected URLs: " + numRequestsToUnexpectedUrls, + 0, increaseFromUnexpectedUrls); + del("*:*"); commit(); } + + private Long getNumRequests(HttpSolrServer solrServer) throws + SolrServerException, IOException { + HttpSolrServer server = new HttpSolrServer(solrServer.getBaseURL()); + server.setConnectionTimeout(15000); + server.setSoTimeout(60000); + ModifiableSolrParams params = new ModifiableSolrParams(); + params.set("qt", "/admin/mbeans"); + params.set("stats", "true"); + params.set("key", "org.apache.solr.handler.StandardRequestHandler"); + params.set("cat", "QUERYHANDLER"); + // use generic request to avoid extra processing of queries + QueryRequest req = new QueryRequest(params); + NamedList resp = server.request(req); + NamedList mbeans = (NamedList) resp.get("solr-mbeans"); + NamedList queryHandler = (NamedList) mbeans.get("QUERYHANDLER"); + NamedList select = (NamedList) queryHandler.get("org.apache.solr.handler.StandardRequestHandler"); + NamedList stats = (NamedList) select.get("stats"); + return (Long) stats.get("requests"); + } @Override protected void indexr(Object... fields) throws Exception {