From debf8555c2fb0f309e3389aa196a788d8c4a5bb5 Mon Sep 17 00:00:00 2001
From: Duo Zhang <zhangduo@apache.org>
Date: Sat, 29 May 2021 10:54:44 +0800
Subject: [PATCH] HBASE-25947 Backport 'HBASE-25894 Improve the performance for
 region load and region count related cost functions' to branch-2.4 and
 branch-2.3

---
 .../master/balancer/DoubleArrayCost.java      | 100 ++++++
 .../balancer/StochasticLoadBalancer.java      | 337 +++++++++---------
 .../master/balancer/TestDoubleArrayCost.java  |  67 ++++
 .../TestStochasticBalancerJmxMetrics.java     |  19 +-
 .../balancer/TestStochasticLoadBalancer.java  |  28 --
 5 files changed, 342 insertions(+), 209 deletions(-)
 create mode 100644 hbase-server/src/main/java/org/apache/hadoop/hbase/master/balancer/DoubleArrayCost.java
 create mode 100644 hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestDoubleArrayCost.java
 rename hbase-server/src/test/java/org/apache/hadoop/hbase/{ => master/balancer}/TestStochasticBalancerJmxMetrics.java (95%)

diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/master/balancer/DoubleArrayCost.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/master/balancer/DoubleArrayCost.java
new file mode 100644
index 000000000000..f370b8077d1b
--- /dev/null
+++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/master/balancer/DoubleArrayCost.java
@@ -0,0 +1,100 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.master.balancer;
+
+import java.util.function.Consumer;
+import org.apache.yetus.audience.InterfaceAudience;
+
+/**
+ * A helper class to compute a scaled cost from an array of values, one per region server: the
+ * sum of the absolute deviations from the mean, scaled to [0, 1] by StochasticLoadBalancer.scale.
+ * It assumes that this is a zero sum set of costs. It assumes that the worst case possible is all
+ * of the elements in one region server and the rest having 0.
+ */
+@InterfaceAudience.Private
+final class DoubleArrayCost {
+
+  private double[] costs;
+
+  // Set by setCosts; cost() recomputes the cached cost (expensive) only while this is true.
+  private boolean costsChanged;
+
+  private double cost;
+
+  /** Ensures the backing array has {@code length} slots, reusing it if the length matches. */
+  void prepare(int length) {
+    if (costs == null || costs.length != length) {
+      costs = new double[length];
+    }
+  }
+
+  /** Lets {@code consumer} fill or update the backing array and marks the cached cost stale. */
+  void setCosts(Consumer<double[]> consumer) {
+    consumer.accept(costs);
+    costsChanged = true;
+  }
+
+  /** Returns the scaled cost, recomputing it only if the costs were set since the last call. */
+  double cost() {
+    if (costsChanged) {
+      cost = computeCost(costs);
+      costsChanged = false;
+    }
+    return cost;
+  }
+
+  /** Sum of absolute deviations from the mean, scaled into [0, 1] between min and max below. */
+  private static double computeCost(double[] stats) {
+    double totalCost = 0;
+    double total = getSum(stats);
+    // NOTE(review): an empty stats array makes mean 0 / 0 = NaN, so NaN is returned.
+    double count = stats.length;
+    double mean = total / count;
+
+    // Compute max as if all region servers had 0 and one had the sum of all costs. This must be
+    // a zero sum cost for this to make sense.
+    double max = ((count - 1) * mean) + (total - mean);
+
+    // It's possible that there aren't enough regions to go around (each server gets 0 or 1).
+    double min;
+    if (count > total) {
+      min = ((count - total) * mean) + ((1 - mean) * total);
+    } else {
+      // Some will have 1 more than everything else.
+      int numHigh = (int) (total - (Math.floor(mean) * count));
+      int numLow = (int) (count - numHigh);
+      min = (numHigh * (Math.ceil(mean) - mean)) + (numLow * (mean - Math.floor(mean)));
+    }
+    min = Math.max(0, min);
+    for (int i = 0; i < stats.length; i++) {
+      double n = stats[i];
+      double diff = Math.abs(mean - n);
+      totalCost += diff;
+    }
+    double scaled = StochasticLoadBalancer.scale(min, max, totalCost);
+    return scaled;
+  }
+
+  private static double getSum(double[] stats) {
+    double total = 0;
+    for (double s : stats) {
+      total += s;
+    }
+    return total;
+  }
+}
\ No newline at end of file
diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/master/balancer/StochasticLoadBalancer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/master/balancer/StochasticLoadBalancer.java
index 3e341d0f7718..26d8d43f5bef 100644
--- a/hbase-server/src/main/java/org/apache/hadoop/hbase/master/balancer/StochasticLoadBalancer.java
+++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/master/balancer/StochasticLoadBalancer.java
@@ -17,12 +17,14 @@
  */
 package org.apache.hadoop.hbase.master.balancer;
 
+import com.google.errorprone.annotations.RestrictedApi;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Deque;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -145,7 +147,6 @@
   private boolean isBalancerRejectionRecording = false;
 
   private List<CandidateGenerator> candidateGenerators;
-  private CostFromRegionLoadFunction[] regionLoadFunctions;
   private List<CostFunction> costFunctions; // FindBugs: Wants this protected; IS2_INCONSISTENT_SYNC
 
   // to save and report costs to JMX
@@ -202,12 +203,6 @@ public synchronized void setConf(Configuration conf) {
       candidateGenerators.add(localityCandidateGenerator);
       candidateGenerators.add(new RegionReplicaRackCandidateGenerator());
     }
-    regionLoadFunctions = new CostFromRegionLoadFunction[] {
-      new ReadRequestCostFunction(conf),
-      new WriteRequestCostFunction(conf),
-      new MemStoreSizeCostFunction(conf),
-      new StoreFileCostFunction(conf)
-    };
     regionReplicaHostCostFunction = new RegionReplicaHostCostFunction(conf);
     regionReplicaRackCostFunction = new RegionReplicaRackCostFunction(conf);
 
@@ -220,10 +215,10 @@ public synchronized void setConf(Configuration conf) {
     addCostFunction(new TableSkewCostFunction(conf));
     addCostFunction(regionReplicaHostCostFunction);
     addCostFunction(regionReplicaRackCostFunction);
-    addCostFunction(regionLoadFunctions[0]);
-    addCostFunction(regionLoadFunctions[1]);
-    addCostFunction(regionLoadFunctions[2]);
-    addCostFunction(regionLoadFunctions[3]);
+    addCostFunction(new ReadRequestCostFunction(conf));
+    addCostFunction(new WriteRequestCostFunction(conf));
+    addCostFunction(new MemStoreSizeCostFunction(conf));
+    addCostFunction(new StoreFileCostFunction(conf));
     loadCustomCostFunctions(conf);
 
     curFunctionCosts = new double[costFunctions.size()];
@@ -290,9 +285,6 @@ protected void setSlop(Configuration conf) {
   public synchronized void setClusterMetrics(ClusterMetrics st) {
     super.setClusterMetrics(st);
     updateRegionLoad();
-    for(CostFromRegionLoadFunction cost : regionLoadFunctions) {
-      cost.setClusterMetrics(st);
-    }
 
     // update metrics size
     try {
@@ -599,12 +591,16 @@ private void addCostFunction(CostFunction costFunction) {
 
   private String functionCost() {
     StringBuilder builder = new StringBuilder();
-    for (CostFunction c:costFunctions) {
+    for (CostFunction c : costFunctions) {
       builder.append(c.getClass().getSimpleName());
       builder.append(" : (");
-      builder.append(c.getMultiplier());
-      builder.append(", ");
-      builder.append(c.cost());
+      if (c.isNeeded()) {
+        builder.append(c.getMultiplier());
+        builder.append(", ");
+        builder.append(c.cost());
+      } else {
+        builder.append("not needed");
+      }
       builder.append("); ");
     }
     return builder.toString();
@@ -613,11 +609,15 @@ private String functionCost() {
   private String totalCostsPerFunc() {
     StringBuilder builder = new StringBuilder();
     for (CostFunction c : costFunctions) {
-      if (c.getMultiplier() * c.cost() > 0.0) {
+      if (c.getMultiplier() <= 0 || !c.isNeeded()) {
+        continue;
+      }
+      double cost = c.getMultiplier() * c.cost();
+      if (cost > 0.0) {
         builder.append(" ");
         builder.append(c.getClass().getSimpleName());
         builder.append(" : ");
-        builder.append(c.getMultiplier() * c.cost());
+        builder.append(cost);
         builder.append(";");
       }
     }
@@ -679,29 +679,32 @@ private synchronized void updateRegionLoad() {
         loads.put(regionNameAsString, rLoads);
       });
     });
-
-    for(CostFromRegionLoadFunction cost : regionLoadFunctions) {
-      cost.setLoads(loads);
-    }
   }
 
-  protected void initCosts(Cluster cluster) {
+  @RestrictedApi(explanation = "Should only be called in tests", link = "",
+    allowedOnPath = ".*(/src/test/.*|StochasticLoadBalancer).java")
+  void initCosts(Cluster cluster) {
     for (CostFunction c:costFunctions) {
       c.init(cluster);
     }
   }
 
-  protected void updateCostsWithAction(Cluster cluster, Action action) {
+  @RestrictedApi(explanation = "Should only be called in tests", link = "",
+    allowedOnPath = ".*(/src/test/.*|StochasticLoadBalancer).java")
+  void updateCostsWithAction(Cluster cluster, Action action) {
     for (CostFunction c : costFunctions) {
-      c.postAction(action);
+      if (c.getMultiplier() > 0 && c.isNeeded()) {
+        c.postAction(action);
+      }
     }
   }
 
   /**
    * Get the names of the cost functions
    */
-  public String[] getCostFunctionNames() {
-    if (costFunctions == null) return null;
+  @RestrictedApi(explanation = "Should only be called in tests", link = "",
+    allowedOnPath = ".*(/src/test/.*|StochasticLoadBalancer).java")
+  String[] getCostFunctionNames() {
     String[] ret = new String[costFunctions.size()];
     for (int i = 0; i < costFunctions.size(); i++) {
       CostFunction c = costFunctions.get(i);
@@ -720,14 +723,16 @@ protected void updateCostsWithAction(Cluster cluster, Action action) {
    * @return a double of a cost associated with the proposed cluster state.  This cost is an
    *         aggregate of all individual cost functions.
    */
-  protected double computeCost(Cluster cluster, double previousCost) {
+  @RestrictedApi(explanation = "Should only be called in tests", link = "",
+    allowedOnPath = ".*(/src/test/.*|StochasticLoadBalancer).java")
+  double computeCost(Cluster cluster, double previousCost) {
     double total = 0;
 
     for (int i = 0; i < costFunctions.size(); i++) {
       CostFunction c = costFunctions.get(i);
       this.tempFunctionCosts[i] = 0.0;
 
-      if (c.getMultiplier() <= 0) {
+      if (c.getMultiplier() <= 0 || !c.isNeeded()) {
         continue;
       }
 
@@ -851,75 +856,24 @@ protected void regionMoved(int region, int oldServer, int newServer) {
     }
 
     protected abstract double cost();
+  }
 
-    @SuppressWarnings("checkstyle:linelength")
-    /**
-     * Function to compute a scaled cost using
-     * {@link org.apache.commons.math3.stat.descriptive.DescriptiveStatistics#DescriptiveStatistics()}.
-     * It assumes that this is a zero sum set of costs.  It assumes that the worst case
-     * possible is all of the elements in one region server and the rest having 0.
-     *
-     * @param stats the costs
-     * @return a scaled set of costs.
-     */
-    protected double costFromArray(double[] stats) {
-      double totalCost = 0;
-      double total = getSum(stats);
-
-      double count = stats.length;
-      double mean = total/count;
-
-      // Compute max as if all region servers had 0 and one had the sum of all costs.  This must be
-      // a zero sum cost for this to make sense.
-      double max = ((count - 1) * mean) + (total - mean);
-
-      // It's possible that there aren't enough regions to go around
-      double min;
-      if (count > total) {
-        min = ((count - total) * mean) + ((1 - mean) * total);
-      } else {
-        // Some will have 1 more than everything else.
-        int numHigh = (int) (total - (Math.floor(mean) * count));
-        int numLow = (int) (count - numHigh);
-
-        min = (numHigh * (Math.ceil(mean) - mean)) + (numLow * (mean - Math.floor(mean)));
-
-      }
-      min = Math.max(0, min);
-      for (int i=0; i<stats.length; i++) {
-        double n = stats[i];
-        double diff = Math.abs(mean - n);
-        totalCost += diff;
-      }
-
-      double scaled =  scale(min, max, totalCost);
-      return scaled;
+  /**
+   * Scale the value between 0 and 1.
+   * @param min Min value
+   * @param max The Max value
+   * @param value The value to be scaled.
+   * @return The scaled value.
+   */
+  static double scale(double min, double max, double value) {
+    if (max <= min || value <= min) {
+      return 0;
     }
-
-    private double getSum(double[] stats) {
-      double total = 0;
-      for(double s:stats) {
-        total += s;
-      }
-      return total;
+    if ((max - min) == 0) {
+      return 0;
     }
 
-    /**
-     * Scale the value between 0 and 1.
-     *
-     * @param min   Min value
-     * @param max   The Max value
-     * @param value The value to be scaled.
-     * @return The scaled value.
-     */
-    protected double scale(double min, double max, double value) {
-      if (max <= min || value <= min) {
-        return 0;
-      }
-      if ((max - min) == 0) return 0;
-
-      return Math.max(0d, Math.min(1d, (value - min) / (max - min)));
-    }
+    return Math.max(0d, Math.min(1d, (value - min) / (max - min)));
   }
 
   /**
@@ -938,28 +892,36 @@ protected double scale(double min, double max, double value) {
     private static final float DEFAULT_MAX_MOVE_PERCENT = 0.25f;
 
     private final float maxMovesPercent;
-    private final Configuration conf;
+    private final OffPeakHours offPeakHours;
+    private final float moveCost;
+    private final float moveCostOffPeak;
 
     MoveCostFunction(Configuration conf) {
       super(conf);
-      this.conf = conf;
       // What percent of the number of regions a single run of the balancer can move.
       maxMovesPercent = conf.getFloat(MAX_MOVES_PERCENT_KEY, DEFAULT_MAX_MOVE_PERCENT);
-
+      offPeakHours = OffPeakHours.getInstance(conf);
+      moveCost = conf.getFloat(MOVE_COST_KEY, DEFAULT_MOVE_COST);
+      moveCostOffPeak = conf.getFloat(MOVE_COST_OFFPEAK_KEY, DEFAULT_MOVE_COST_OFFPEAK);
       // Initialize the multiplier so that addCostFunction will add this cost function.
       // It may change during later evaluations, due to OffPeakHours.
-      this.setMultiplier(conf.getFloat(MOVE_COST_KEY, DEFAULT_MOVE_COST));
+      this.setMultiplier(moveCost);
     }
 
     @Override
-    protected double cost() {
+    void init(Cluster cluster) {
+      super.init(cluster);
       // Move cost multiplier should be the same cost or higher than the rest of the costs to ensure
       // that large benefits are need to overcome the cost of a move.
-      if (OffPeakHours.getInstance(conf).isOffPeakHour()) {
-        this.setMultiplier(conf.getFloat(MOVE_COST_OFFPEAK_KEY, DEFAULT_MOVE_COST_OFFPEAK));
+      if (offPeakHours.isOffPeakHour()) {
+        this.setMultiplier(moveCostOffPeak);
       } else {
-        this.setMultiplier(conf.getFloat(MOVE_COST_KEY, DEFAULT_MOVE_COST));
+        this.setMultiplier(moveCost);
       }
+    }
+
+    @Override
+    protected double cost() {
       // Try and size the max number of Moves, but always be prepared to move some.
       int maxMoves = Math.max((int) (cluster.numRegions * maxMovesPercent),
           DEFAULT_MAX_MOVES);
@@ -985,7 +947,7 @@ protected double cost() {
         "hbase.master.balancer.stochastic.regionCountCost";
     static final float DEFAULT_REGION_COUNT_SKEW_COST = 500;
 
-    private double[] stats = null;
+    private final DoubleArrayCost cost = new DoubleArrayCost();
 
     RegionCountSkewCostFunction(Configuration conf) {
       super(conf);
@@ -996,8 +958,14 @@ protected double cost() {
     @Override
     void init(Cluster cluster) {
       super.init(cluster);
+      cost.prepare(cluster.numServers);
+      cost.setCosts(costs -> {
+        for (int i = 0; i < cluster.numServers; i++) {
+          costs[i] = cluster.regionsPerServer[i].length;
+        }
+      });
       LOG.debug("{} sees a total of {} servers and {} regions.", getClass().getSimpleName(),
-          cluster.numServers, cluster.numRegions);
+        cluster.numServers, cluster.numRegions);
       if (LOG.isTraceEnabled()) {
         for (int i =0; i < cluster.numServers; i++) {
           LOG.trace("{} sees server '{}' has {} regions", getClass().getSimpleName(),
@@ -1008,13 +976,15 @@ void init(Cluster cluster) {
 
     @Override
     protected double cost() {
-      if (stats == null || stats.length != cluster.numServers) {
-        stats = new double[cluster.numServers];
-      }
-      for (int i =0; i < cluster.numServers; i++) {
-        stats[i] = cluster.regionsPerServer[i].length;
-      }
-      return costFromArray(stats);
+      return cost.cost();
+    }
+
+    @Override
+    protected void regionMoved(int region, int oldServer, int newServer) {
+      cost.setCosts(costs -> {
+        costs[oldServer] = cluster.regionsPerServer[oldServer].length;
+        costs[newServer] = cluster.regionsPerServer[newServer].length;
+      });
     }
   }
 
@@ -1027,7 +997,7 @@ protected double cost() {
         "hbase.master.balancer.stochastic.primaryRegionCountCost";
     private static final float DEFAULT_PRIMARY_REGION_COUNT_SKEW_COST = 500;
 
-    private double[] stats = null;
+    private final DoubleArrayCost cost = new DoubleArrayCost();
 
     PrimaryRegionCountSkewCostFunction(Configuration conf) {
       super(conf);
@@ -1036,30 +1006,45 @@ protected double cost() {
         DEFAULT_PRIMARY_REGION_COUNT_SKEW_COST));
     }
 
+    private double computeCostForRegionServer(int regionServerIndex) {
+      int cost = 0;
+      for (int regionIdx : cluster.regionsPerServer[regionServerIndex]) {
+        if (regionIdx == cluster.regionIndexToPrimaryIndex[regionIdx]) {
+          cost++;
+        }
+      }
+      return cost;
+    }
+
+    @Override
+    void init(Cluster cluster) {
+      super.init(cluster);
+      if (!isNeeded()) {
+        return;
+      }
+      cost.prepare(cluster.numServers);
+      cost.setCosts(costs -> {
+        for (int i = 0; i < costs.length; i++) {
+          costs[i] = computeCostForRegionServer(i);
+        }
+      });
+    }
+
     @Override
     boolean isNeeded() {
       return cluster.hasRegionReplicas;
     }
 
+    @Override
+    protected void regionMoved(int region, int oldServer, int newServer) {
+      cost.setCosts(costs -> {
+        costs[oldServer] = computeCostForRegionServer(oldServer);
+        costs[newServer] = computeCostForRegionServer(newServer);
+      });
+    }
     @Override
     protected double cost() {
-      if (!cluster.hasRegionReplicas) {
-        return 0;
-      }
-      if (stats == null || stats.length != cluster.numServers) {
-        stats = new double[cluster.numServers];
-      }
-
-      for (int i = 0; i < cluster.numServers; i++) {
-        stats[i] = 0;
-        for (int regionIdx : cluster.regionsPerServer[i]) {
-          if (regionIdx == cluster.regionIndexToPrimaryIndex[regionIdx]) {
-            stats[i]++;
-          }
-        }
-      }
-
-      return costFromArray(stats);
+      return cost.cost();
     }
   }
 
@@ -1194,51 +1179,51 @@ int regionIndexToEntityIndex(int region) {
    */
   abstract static class CostFromRegionLoadFunction extends CostFunction {
 
-    private ClusterMetrics clusterStatus = null;
-    private Map<String, Deque<BalancerRegionLoad>> loads = null;
-    private double[] stats = null;
+    private final DoubleArrayCost cost = new DoubleArrayCost();
+
     CostFromRegionLoadFunction(Configuration conf) {
       super(conf);
     }
 
-    void setClusterMetrics(ClusterMetrics status) {
-      this.clusterStatus = status;
-    }
-
-    void setLoads(Map<String, Deque<BalancerRegionLoad>> l) {
-      this.loads = l;
-    }
+    private double computeCostForRegionServer(int regionServerIndex) {
+      // Cost this server has from RegionLoad
+      double cost = 0;
 
-    @Override
-    protected double cost() {
-      if (clusterStatus == null || loads == null) {
-        return 0;
-      }
+      // for every region on this server get the rl
+      for (int regionIndex : cluster.regionsPerServer[regionServerIndex]) {
+        Collection<BalancerRegionLoad> regionLoadList = cluster.regionLoads[regionIndex];
 
-      if (stats == null || stats.length != cluster.numServers) {
-        stats = new double[cluster.numServers];
+        // Now if we found a region load get the type of cost that was requested.
+        if (regionLoadList != null) {
+          cost += getRegionLoadCost(regionLoadList);
+        }
       }
+      return cost;
+    }
 
-      for (int i =0; i < stats.length; i++) {
-        //Cost this server has from RegionLoad
-        long cost = 0;
-
-        // for every region on this server get the rl
-        for(int regionIndex:cluster.regionsPerServer[i]) {
-          Collection<BalancerRegionLoad> regionLoadList =  cluster.regionLoads[regionIndex];
-
-          // Now if we found a region load get the type of cost that was requested.
-          if (regionLoadList != null) {
-            cost = (long) (cost + getRegionLoadCost(regionLoadList));
-          }
+    @Override
+    void init(Cluster cluster) {
+      super.init(cluster);
+      cost.prepare(cluster.numServers);
+      cost.setCosts(costs -> {
+        for (int i = 0; i < costs.length; i++) {
+          costs[i] = computeCostForRegionServer(i);
         }
+      });
+    }
 
-        // Add the total cost to the stats.
-        stats[i] = cost;
-      }
+    @Override
+    protected void regionMoved(int region, int oldServer, int newServer) {
+      // recompute the stat for the given two region servers
+      cost.setCosts(costs -> {
+        costs[oldServer] = computeCostForRegionServer(oldServer);
+        costs[newServer] = computeCostForRegionServer(newServer);
+      });
+    }
 
-      // Now return the scaled cost from data held in the stats object.
-      return costFromArray(stats);
+    @Override
+    protected final double cost() {
+      return cost.cost();
     }
 
     protected double getRegionLoadCost(Collection<BalancerRegionLoad> regionLoadList) {
@@ -1265,18 +1250,20 @@ protected double getRegionLoadCost(Collection<BalancerRegionLoad> regionLoadList
 
     @Override
     protected double getRegionLoadCost(Collection<BalancerRegionLoad> regionLoadList) {
+      Iterator<BalancerRegionLoad> iter = regionLoadList.iterator();
+      if (!iter.hasNext()) {
+        return 0;
+      }
+      double previous = getCostFromRl(iter.next());
+      if (!iter.hasNext()) {
+        return 0;
+      }
       double cost = 0;
-      double previous = 0;
-      boolean isFirst = true;
-      for (BalancerRegionLoad rl : regionLoadList) {
-        double current = getCostFromRl(rl);
-        if (isFirst) {
-          isFirst = false;
-        } else {
-          cost += current - previous;
-        }
+      do {
+        double current = getCostFromRl(iter.next());
+        cost += current - previous;
         previous = current;
-      }
+      } while (iter.hasNext());
       return Math.max(0, cost / (regionLoadList.size() - 1));
     }
   }
diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestDoubleArrayCost.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestDoubleArrayCost.java
new file mode 100644
index 000000000000..8dd1e4973b68
--- /dev/null
+++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestDoubleArrayCost.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.master.balancer;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.hadoop.hbase.HBaseClassTestRule;
+import org.apache.hadoop.hbase.testclassification.MasterTests;
+import org.apache.hadoop.hbase.testclassification.SmallTests;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+@Category({ MasterTests.class, SmallTests.class })
+public class TestDoubleArrayCost {
+
+  @ClassRule
+  public static final HBaseClassTestRule CLASS_RULE =
+    HBaseClassTestRule.forClass(TestDoubleArrayCost.class);
+
+  @Test
+  public void testComputeCost() {
+    DoubleArrayCost cost = new DoubleArrayCost();
+    // All 100 servers at 10: nothing deviates from the mean, so the scaled cost is 0.
+    cost.prepare(100);
+    cost.setCosts(costs -> {
+      for (int i = 0; i < 100; i++) {
+        costs[i] = 10;
+      }
+    });
+    assertEquals(0, cost.cost(), 0.01);
+    // 100 servers at 0 and one at 100 is the worst possible spread, so the scaled cost is 1.
+    cost.prepare(101);
+    cost.setCosts(costs -> {
+      for (int i = 0; i < 100; i++) {
+        costs[i] = 0;
+      }
+      costs[100] = 100;
+    });
+    assertEquals(1, cost.cost(), 0.01);
+    // Servers 0-99 at 0, 100-199 at 100 (costs[100] is set twice): 10000 / 19900, about 0.5.
+    cost.prepare(200);
+    cost.setCosts(costs -> {
+      for (int i = 0; i < 100; i++) {
+        costs[i] = 0;
+        costs[i + 100] = 100;
+      }
+      costs[100] = 100;
+    });
+    assertEquals(0.5, cost.cost(), 0.01);
+  }
+}
\ No newline at end of file
diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/TestStochasticBalancerJmxMetrics.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticBalancerJmxMetrics.java
similarity index 95%
rename from hbase-server/src/test/java/org/apache/hadoop/hbase/TestStochasticBalancerJmxMetrics.java
rename to hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticBalancerJmxMetrics.java
index 16d2c4d7c5ba..ab1c76c2e9dd 100644
--- a/hbase-server/src/test/java/org/apache/hadoop/hbase/TestStochasticBalancerJmxMetrics.java
+++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticBalancerJmxMetrics.java
@@ -15,7 +15,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.hadoop.hbase;
+package org.apache.hadoop.hbase.master.balancer;
 
 import static org.junit.Assert.assertTrue;
 
@@ -35,10 +35,14 @@
 import javax.management.remote.JMXConnector;
 import javax.management.remote.JMXConnectorFactory;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseClassTestRule;
+import org.apache.hadoop.hbase.HBaseTestingUtility;
+import org.apache.hadoop.hbase.HConstants;
+import org.apache.hadoop.hbase.JMXListener;
+import org.apache.hadoop.hbase.ServerName;
+import org.apache.hadoop.hbase.TableName;
 import org.apache.hadoop.hbase.client.RegionInfo;
 import org.apache.hadoop.hbase.coprocessor.CoprocessorHost;
-import org.apache.hadoop.hbase.master.balancer.BalancerTestBase;
-import org.apache.hadoop.hbase.master.balancer.StochasticLoadBalancer;
 import org.apache.hadoop.hbase.testclassification.MediumTests;
 import org.apache.hadoop.hbase.testclassification.MiscTests;
 import org.apache.hadoop.hbase.util.Threads;
@@ -199,7 +203,9 @@ public void testJmxMetrics_PerTableMode() throws Exception {
     final int count = 0;
     for (int i = 0; i < 10; i++) {
       Set<String> metrics = readJmxMetrics();
-      if (metrics != null) return metrics;
+      if (metrics != null) {
+        return metrics;
+      }
       LOG.warn("Failed to get jmxmetrics... sleeping, retrying; " + i + " of " + count + " times");
       Threads.sleep(1000);
     }
@@ -208,7 +214,6 @@ public void testJmxMetrics_PerTableMode() throws Exception {
 
   /**
    * Read the attributes from Hadoop->HBase->Master->Balancer in JMX
-   * @throws IOException
    */
   private Set<String> readJmxMetrics() throws IOException {
     JMXConnector connector = null;
@@ -273,7 +278,9 @@ public void testJmxMetrics_PerTableMode() throws Exception {
   }
 
   private static void printMetrics(Set<String> metrics, String info) {
-    if (null != info) LOG.info("++++ ------ " + info + " ------");
+    if (null != info) {
+      LOG.info("++++ ------ " + info + " ------");
+    }
 
     LOG.info("++++ metrics count = " + metrics.size());
     for (String str : metrics) {
diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticLoadBalancer.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticLoadBalancer.java
index b97679f0470f..ea65f96eaf51 100644
--- a/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticLoadBalancer.java
+++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticLoadBalancer.java
@@ -358,34 +358,6 @@ public void testRegionLoadCost() {
     assertEquals(2.5, result, 0.01);
   }
 
-  @Test
-  public void testCostFromArray() {
-    Configuration conf = HBaseConfiguration.create();
-    StochasticLoadBalancer.CostFromRegionLoadFunction
-        costFunction = new StochasticLoadBalancer.MemStoreSizeCostFunction(conf);
-    costFunction.init(mockCluster(new int[]{0, 0, 0, 0, 1}));
-
-    double[] statOne = new double[100];
-    for (int i =0; i < 100; i++) {
-      statOne[i] = 10;
-    }
-    assertEquals(0, costFunction.costFromArray(statOne), 0.01);
-
-    double[] statTwo= new double[101];
-    for (int i =0; i < 100; i++) {
-      statTwo[i] = 0;
-    }
-    statTwo[100] = 100;
-    assertEquals(1, costFunction.costFromArray(statTwo), 0.01);
-
-    double[] statThree = new double[200];
-    for (int i =0; i < 100; i++) {
-      statThree[i] = (0);
-      statThree[i+100] = 100;
-    }
-    assertEquals(0.5, costFunction.costFromArray(statThree), 0.01);
-  }
-
   @Test
   public void testLosingRs() throws Exception {
     int numNodes = 3;